Unverified commit 85e1a6f3 authored by Yineng Zhang, committed by GitHub

Update model_loader deps and qqq quantization deps (#2220) (#2318)


Co-authored-by: HandH1998 <1335248067@qq.com>
parent 33deca81
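
Most of this diff is a mechanical dependency migration: each model file stops importing default_weight_loader from vllm's model loader and instead pulls it from sglang's new in-tree sglang.srt.model_loader package, and the vllm-era cache_config constructor parameter (unused on the sglang side) is dropped from every model class and call site. A hedged before/after sketch of the pattern that repeats in the hunks below:

# Old import, removed in every model file touched here:
#   from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# New import, added in its place:
from sglang.srt.model_loader.weight_utils import default_weight_loader

# Constructors lose the unused cache_config parameter, for example:
#   def __init__(self, config, quant_config=None, cache_config=None): ...   # before
#   def __init__(self, config, quant_config=None): ...                      # after
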
......@@ -29,7 +29,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class MixtralMLP(nn.Module):
......@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
......
......@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
_prepare_aspect_ratio_attention_mask,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.layernorm import RMSNorm
......@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
......@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
self,
config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig],
cache_config=None,
):
super().__init__()
self.padding_id = config.pad_token_id
......@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
self,
config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig],
cache_config=None,
):
super().__init__()
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, cache_config, quant_config)
self.model = MllamaTextModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
......@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
self,
config: config_mllama.MllamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
self.vocab_size = config.text_config.vocab_size
......@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
self.vision_model = MllamaVisionModel(config.vision_config)
self.language_model = MllamaForCausalLM(
config.text_config,
cache_config=cache_config,
quant_config=quant_config,
)
self.multi_modal_projector = nn.Linear(
......
......@@ -22,7 +22,6 @@ from torch import nn
from transformers import OlmoConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import (
......@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
......@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
def __init__(
self,
config: OlmoConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......
......@@ -312,7 +312,6 @@ class Olmo2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......
......@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.utils import print_warning_once
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
......@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers, print_warning_once
class OlmoeMoE(nn.Module):
......@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......
......@@ -7,8 +7,6 @@ from transformers import Phi3Config
from transformers.configuration_utils import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import make_layers
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
......@@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
@torch.jit.script
......@@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module):
super().__init__()
self.config = config
cache_config = None
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size
)
......@@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(
config, int(prefix.split(".")[-1]), cache_config, quant_config
config, int(prefix.split(".")[-1]), quant_config
),
prefix=f"{prefix}.layers",
)
......@@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module):
self,
config: Phi3Config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
......
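
The Phi3Small hunks above also switch make_layers from vllm.model_executor.models.utils to sglang.srt.utils; the per-layer callback receives a dotted weight prefix and recovers its integer layer id from the last component. A minimal standalone sketch of that pattern follows (a simplified stand-in that only assumes the three-value return shape unpacked above; the real helper also handles pipeline-style layer partitioning):

from typing import Callable, Tuple

import torch.nn as nn

def make_layers_sketch(
    num_hidden_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    # One module per layer, each handed its fully qualified prefix.
    layers = nn.ModuleList(
        layer_fn(f"{prefix}.{i}") for i in range(num_hidden_layers)
    )
    return 0, num_hidden_layers, layers

# Mirrors the diff above, where the callback is
#   lambda prefix: Phi3SmallDecoderLayer(config, int(prefix.split(".")[-1]), quant_config)
start_layer, end_layer, layers = make_layers_sketch(
    4, lambda p: nn.Identity()  # a real callback would use int(p.split(".")[-1]) as the layer id
)
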
......@@ -22,7 +22,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
......@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class QWenMLP(nn.Module):
......@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
self.config = config
......
......@@ -22,7 +22,6 @@ import torch
from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
......@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
Qwen2Config = None
......@@ -271,7 +271,6 @@ class Qwen2ForCausalLM(nn.Module):
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
......
......@@ -27,7 +27,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
......@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class Qwen2MoeMLP(nn.Module):
......@@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
......@@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module):
)
self.layers = nn.ModuleList(
[
Qwen2MoeDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
]
)
......@@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
......
......@@ -30,12 +30,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
......@@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
logger = init_logger(__name__)
......@@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
def __init__(
self,
config: Qwen2VLConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
import importlib
import logging
import pkgutil
from dataclasses import dataclass, field
from functools import lru_cache
from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union
import torch.nn as nn
logger = logging.getLogger(__name__)
@dataclass
class _ModelRegistry:
# Keyed by model_arch
models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
def get_supported_archs(self) -> AbstractSet[str]:
return self.models.keys()
def _raise_for_unsupported(self, architectures: List[str]):
all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures):
raise ValueError(
f"Model architectures {architectures} failed "
"to be inspected. Please check the logs for more details."
)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}"
)
def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in self.models:
return None
return self.models[model_arch]
def _normalize_archs(
self,
architectures: Union[str, List[str]],
) -> List[str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
return architectures
def resolve_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[Type[nn.Module], str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
return self._raise_for_unsupported(architectures)
@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}. " f"{e}")
continue
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry:
assert (
tmp.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {tmp.__name__}"
model_arch_name_to_cls[tmp.__name__] = tmp
else:
assert (
entry.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {entry.__name__}"
model_arch_name_to_cls[entry.__name__] = entry
return model_arch_name_to_cls
ModelRegistry = _ModelRegistry(import_model_classes())
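
The registry above replaces its vllm counterpart: import_model_classes walks sglang.srt.models and collects each module's EntryClass, and resolve_model_cls maps a Hugging Face architecture string to the implementation class. A hedged sketch of both sides (the loader-side instantiation is an assumption consistent with the slimmed constructors in this diff, and MyNewForCausalLM is a hypothetical name):

from sglang.srt.models.registry import ModelRegistry

# Look up the implementation class for an architecture string
# (normally taken from hf_config.architectures).
model_cls, arch = ModelRegistry.resolve_model_cls(["Qwen2MoeForCausalLM"])

# The loader is then assumed to instantiate it with only the HF config and the
# quantization config, which is consistent with cache_config disappearing from
# every constructor in this diff:
#   model = model_cls(config=hf_config, quant_config=quant_config)

# Registering a model uses the EntryClass convention checked by
# import_model_classes(): a hypothetical sglang/srt/models/my_new_model.py
# simply ends with
#   EntryClass = MyNewForCausalLM   # or a list of classes
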
......@@ -26,7 +26,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import (
......@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class StablelmMLP(nn.Module):
......@@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
......
......@@ -52,7 +52,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
......@@ -66,6 +65,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
......@@ -388,7 +388,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
......
......@@ -30,7 +30,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class XverseMLP(nn.Module):
......@@ -295,8 +295,6 @@ class XverseForCausalLM(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.config = config
......
......@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
......@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class XverseMLP(nn.Module):
......@@ -181,7 +181,6 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -258,7 +257,6 @@ class XverseDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -277,7 +275,6 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if config.num_experts is not None:
......@@ -326,7 +323,6 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -339,9 +335,7 @@ class XverseModel(nn.Module):
)
self.layers = nn.ModuleList(
[
XverseDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
XverseDecoderLayer(config, layer_id, quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
]
)
......@@ -369,13 +363,12 @@ class XverseMoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config)
self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
......
......@@ -18,9 +18,9 @@ from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llava import LlavaLlamaForCausalLM
......@@ -29,9 +29,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__(config, quant_config, cache_config)
super().__init__(config, quant_config)
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
......
......@@ -50,6 +50,7 @@ class ServerArgs:
served_model_name: Optional[str] = None
chat_template: Optional[str] = None
is_embedding: bool = False
revision: Optional[str] = None
# Port
host: str = "127.0.0.1"
......@@ -341,6 +342,14 @@ class ServerArgs:
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="The specific model version to use. It can be a branch "
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
# Memory and scheduling
parser.add_argument(
......
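
The server_args hunks above add a revision field and a matching --revision CLI flag. A hedged sketch of the option in programmatic use; model_path is assumed to be an existing required ServerArgs field that this hunk does not show:

from sglang.srt.server_args import ServerArgs

# Pin the served model to a specific branch, tag, or commit id on the Hub.
args = ServerArgs(
    model_path="Qwen/Qwen2-7B-Instruct",  # assumed existing field, not part of this diff
    revision="main",                      # field added in this diff
)
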
......@@ -430,16 +430,12 @@ def suppress_other_loggers():
from vllm.logger import logger as vllm_default_logger
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
logging.WARN
)
logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable"
......@@ -492,27 +488,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
pass
def monkey_patch_vllm_model_config():
from vllm.config import ModelConfig
if not hasattr(ModelConfig, "_resolve_task"):
return
def _resolve_task(
self,
task_option,
hf_config,
):
supported_tasks = {
"generate": True,
"embedding": False,
}
selected_task = "generate"
return supported_tasks, selected_task
setattr(ModelConfig, "_resolve_task", _resolve_task)
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
"""
Monkey patch the slow p2p access check in vllm.
......@@ -1041,6 +1016,11 @@ def crash_on_warnings():
return get_bool_env_var("SGLANG_IS_IN_CI")
def print_warning_once(msg: str) -> None:
# Set the stacklevel to 2 to print the caller's line info
logger.warning(msg, stacklevel=2)
def get_device_name(device_id: int = 0) -> str:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return torch.cuda.get_device_name(device_id)
......@@ -1055,6 +1035,33 @@ def get_device_name(device_id: int = 0) -> str:
return torch.hpu.get_device_name(device_id)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
major, minor = None, None
if hasattr(torch, "cuda") and torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(device_id)
if hasattr(torch, "hip") and torch.hip.is_available():
major, minor = torch.cuda.get_device_capability(device_id)
if hasattr(torch, "xpu") and torch.xpu.is_available():
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
"."
)
major, minor = int(major), int(minor)
# TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
# Update this once the support is available.
if hasattr(torch, "hpu") and torch.hpu.is_available():
try:
major, minor = torch.hpu.get_device_capability(device_id)
except Exception as e:
raise RuntimeError(
f"An error occurred while getting device capability of hpu: {e}."
) from e
return major, minor
sglang_lib = Library("sglang", "FRAGMENT") # noqa
......
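
The utils.py hunks above also add print_warning_once (previously imported from vllm.utils, as the OLMoE change shows) and get_device_capability. A hedged usage sketch; the sm80 threshold is an illustrative assumption about when QQQ-style quantized kernels are usable, not something stated in this diff:

from sglang.srt.utils import get_device_capability, print_warning_once

major, minor = get_device_capability()
if major is not None and (major, minor) >= (8, 0):
    # Illustrative gate: quantized GEMM kernels that assume Ampere or newer.
    use_quantized_kernels = True
else:
    use_quantized_kernels = False
    # logger.warning with stacklevel=2, so the record points at this call site.
    print_warning_once("Quantized kernels unavailable on this device; using unquantized weights.")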