"vscode:/vscode.git/clone" did not exist on "5967d81782e4dc41c95815de44885703fd1d0259"
Unverified Commit 85e1a6f3 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

Update model_loader deps and qqq quantization deps (#2220) (#2318)


Co-authored-by: HandH1998 <1335248067@qq.com>
parent 33deca81
@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
     model_config = ModelConfig(
         server_args.model_path,
         trust_remote_code=server_args.trust_remote_code,
+        revision=server_args.revision,
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
+        is_embedding=server_args.is_embedding,
+        dtype=server_args.dtype,
+        quantization=server_args.quantization,
     )
     model_runner = ModelRunner(
         model_config=model_config,
......
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


class DeviceConfig:
    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        if device in ["cuda", "xpu", "hpu"]:
            self.device_type = device
        else:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device = torch.device(self.device_type)
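For context, a minimal editorial usage sketch of the new DeviceConfig (the module path sglang.srt.configs.device_config is inferred from the imports added later in this commit; the snippet itself is not part of the diff):

from sglang.srt.configs.device_config import DeviceConfig  # path assumed

# Valid device strings are "cuda", "xpu", and "hpu".
dev_cfg = DeviceConfig("cuda")
print(dev_cfg.device_type)  # "cuda"
print(dev_cfg.device)       # device(type='cuda')

# Any other string fails fast:
# DeviceConfig("tpu")  ->  RuntimeError: Not supported device type: tpu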
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union

from sglang.srt.utils import is_hip

logger = logging.getLogger(__name__)


class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"


@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(model_loader_extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
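An editorial usage sketch (not part of the diff) of how LoadConfig normalizes its inputs: a string load_format is lowercased into the LoadFormat enum, a JSON string passed as model_loader_extra_config is parsed into a dict, and ignore_patterns falls back to ["original/**/*"] when unset. The import path is assumed to be sglang.srt.configs.load_config.

from sglang.srt.configs.load_config import LoadConfig, LoadFormat  # path assumed

cfg = LoadConfig(
    load_format="safetensors",
    model_loader_extra_config='{"some_key": 1}',  # JSON string parsed in __post_init__
)
assert cfg.load_format is LoadFormat.SAFETENSORS
assert cfg.model_loader_extra_config == {"some_key": 1}
assert cfg.ignore_patterns == ["original/**/*"]  # default when none are given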
@@ -15,12 +15,14 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional
+from typing import List, Optional, Union
 
+import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
 class ModelConfig:
     def __init__(
         self,
-        path: str,
+        model_path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
         is_embedding: Optional[bool] = None,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
     ) -> None:
+        self.model_path = model_path
+        self.revision = revision
+        self.quantization = quantization
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            path,
+            model_path,
             trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        self._verify_quantization()
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ class ModelConfig:
         # parallel size so each GPU has at least one KV head.
         return max(1, total_num_kv_heads // tensor_parallel_size)
 
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
+        return quant_cfg
+
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = [
+            "awq",
+            "gptq",
+            "fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "fbgemm_fp8",
+        ]
+        optimized_quantization_methods = [
+            "fp8",
+            "marlin",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "experts_int8",
+        ]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = self._parse_quant_hf_config()
+
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Detect which checkpoint is it
+            for _, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization
+                )
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization})."
+                )
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}."
+                )
+            if is_hip() and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in ROCm."
+                )
+            if self.quantization not in optimized_quantization_methods:
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.",
+                    self.quantization,
+                )
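For reference, an editorial sketch (not part of the commit) of the reconciliation rule that _verify_quantization applies between the --quantization CLI argument and the checkpoint's quantization_config; the override_quantization_method step, which can rewrite the method name (e.g. to a Marlin variant), is omitted here and the dict values are made up.

# Illustrative only: the core check reproduced on plain dicts.
hf_quant_cfg = {"quant_method": "gptq"}  # what the checkpoint declares
cli_quantization = "awq"                 # what --quantization requested

quant_method = hf_quant_cfg.get("quant_method", "").lower()
if cli_quantization is None:
    cli_quantization = quant_method      # checkpoint wins when the CLI flag is unset
elif cli_quantization != quant_method:
    raise ValueError(
        f"Quantization method specified in the model config ({quant_method}) "
        f"does not match the quantization method specified in the "
        f"`quantization` argument ({cli_quantization})."
    )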
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
     if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
         # We support non-hf version of llava models, so we do not want to
         # read the wrong values from the unused default text_config.
+        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
+        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
+        setattr(config, "torch_dtype", torch.float16)
         return config
 
     if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16."
+                    )
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
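For reference, an editorial sketch (not part of the commit) of how _get_and_verify_dtype resolves the common cases; SimpleNamespace stands in for a PretrainedConfig, and the import path is assumed from this file's location in sglang.srt.configs.model_config.

from types import SimpleNamespace

import torch

from sglang.srt.configs.model_config import _get_and_verify_dtype  # path assumed

# A float32 checkpoint with dtype="auto" is downcast to float16 ...
cfg = SimpleNamespace(torch_dtype=torch.float32, model_type="llama")
assert _get_and_verify_dtype(cfg, "auto") == torch.float16

# ... except for Gemma 2, which is downcast to bfloat16 instead.
cfg = SimpleNamespace(torch_dtype=torch.float32, model_type="gemma2")
assert _get_and_verify_dtype(cfg, "auto") == torch.bfloat16

# An explicit string always wins over the checkpoint dtype (with a warning).
cfg = SimpleNamespace(torch_dtype=torch.bfloat16, model_type="llama")
assert _get_and_verify_dtype(cfg, "float16") == torch.float16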
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architectue
......
@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_scaling = rope_scaling
 
-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        #     if self.rope_scaling["type"] == "mrope":
-        #         self.rope_scaling["type"] = "default"
-        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        #     rope_config_validation(self)
+        # NOTE(HandH1998): This is necessary for configuring the `rope_type` of qwen2vl models after removing dependencies on vllm.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
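Editorial illustration (not part of the commit) of what the re-enabled block does to an mrope-style rope_scaling dict; the mrope_section values are made up.

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"
    rope_scaling["rope_type"] = rope_scaling["type"]

# rope_scaling is now:
# {"type": "default", "mrope_section": [16, 24, 24], "rope_type": "default"}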
@@ -75,6 +75,8 @@ def get_config(
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
+        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
+        setattr(config, "_name_or_path", model)
     if model_override_args:
         config.update(model_override_args)
......
@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "Fp8LinearMethod",
     "MarlinLinearMethod",
     "GPTQLinearMethod",
+    "QQQLinearMethod",
 ]
......
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
 class BaseLayerWithLoRA(nn.Module):
......
@@ -147,9 +147,12 @@ class Scheduler:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation
......
@@ -109,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation
......
@@ -52,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
......
@@ -14,22 +14,12 @@
 """ModelRunner runs the forward passes of the models."""
 
 import gc
-import importlib
-import importlib.resources
-import inspect
 import json
 import logging
-import pkgutil
-import time
-from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional
 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
@@ -37,9 +27,9 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry
 
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
     is_hip,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -228,49 +217,6 @@ class ModelRunner:
 
         return min_per_gpu_memory
 
-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-        if "task" in sig.parameters:
-            params["task"] = ""
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -284,6 +230,7 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            self.model_config.dtype = torch.float16
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
@@ -292,23 +239,21 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )
-        self.model = self.setup_model()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )
 
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.vllm_model_config.dtype
+        self.dtype = self.model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -319,12 +264,12 @@ class ModelRunner:
     def update_weights_from_disk(self, model_path: str, load_format: str):
         """Update engine weights online from disk."""
-        from vllm.model_executor.model_loader.loader import (
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+        from sglang.srt.model_loader.utils import set_default_torch_dtype
 
         logger.info(
             f"Update engine weights online from disk begin. "
@@ -332,15 +277,7 @@ class ModelRunner:
         )
         target_device = torch.device(self.device)
 
-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
+        self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
@@ -352,7 +289,7 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.model,
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True
@@ -370,9 +307,9 @@ class ModelRunner:
                     quant_method.process_weights_after_loading(module)
             return model
 
-        with set_default_torch_dtype(vllm_model_config.dtype):
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(vllm_model_config)
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message
@@ -384,16 +321,14 @@ class ModelRunner:
                 )
                 del iter
                 gc.collect()
-                iter = get_weight_iter(self.vllm_model_config)
+                iter = get_weight_iter(self.model_config)
                 self.model = model_load_weights(self.model, iter)
                 return False, message
 
         self.model = model
         self.server_args.model_path = model_path
         self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
         self.load_config = load_config
-        self.model_config.path = model_path
 
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
@@ -794,55 +729,3 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(
-                    entry, list
-                ):  # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py

from torch import nn

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
from sglang.srt.model_loader.utils import (
    get_architecture_class_name,
    get_model_architecture,
)


def get_model(
    *,
    model_config: ModelConfig,
    load_config: LoadConfig,
    device_config: DeviceConfig,
) -> nn.Module:
    loader = get_model_loader(load_config)
    return loader.load_model(
        model_config=model_config,
        device_config=device_config,
    )


__all__ = [
    "get_model",
    "get_model_loader",
    "BaseModelLoader",
    "get_architecture_class_name",
    "get_model_architecture",
]
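A hedged end-to-end sketch (not part of the diff) of the loading path that ModelRunner.load_model now follows, wiring the three sglang-native configs into get_model. The model path below is only a placeholder, and constructing ModelConfig will fetch the HF config for whatever path is supplied.

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model

# "meta-llama/Llama-3.1-8B-Instruct" is a placeholder; use any local or HF model path.
model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",
    model_override_args="{}",  # expected as a JSON string by ModelConfig
    dtype="auto",
    quantization=None,
)
load_config = LoadConfig(load_format="auto")

model = get_model(
    model_config=model_config,
    load_config=load_config,
    device_config=DeviceConfig("cuda"),
)
print(type(model), next(model.parameters()).dtype)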
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/utils.py
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type

import torch
from torch import nn

from sglang.srt.configs.model_config import ModelConfig


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    from sglang.srt.models.registry import ModelRegistry

    architectures = getattr(model_config.hf_config, "architectures", [])
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"]

    if (
        model_config.quantization is not None
        and model_config.quantization not in mixtral_supported
        and "MixtralForCausalLM" in architectures
    ):
        architectures = ["QuantMixtralForCausalLM"]

    return ModelRegistry.resolve_model_cls(architectures)


def get_architecture_class_name(model_config: ModelConfig) -> str:
    return get_model_architecture(model_config)[1]
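A short editorial usage sketch (not part of the diff) for set_default_torch_dtype, which the loader uses so that freshly created parameters are materialized in the model's dtype and the previous default is restored afterwards.

import torch

from sglang.srt.model_loader.utils import set_default_torch_dtype

print(torch.get_default_dtype())          # torch.float32
with set_default_torch_dtype(torch.float16):
    x = torch.empty(2, 2)                 # created as float16 inside the context
    print(x.dtype)                        # torch.float16
print(torch.get_default_dtype())          # back to torch.float32

Note that, like the upstream vllm helper, the context manager restores the old default only on normal exit; there is no try/finally around the yield.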
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", cache_config, quant_config)
+            super().__init__(config, "ROPE", quant_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", cache_config, quant_config)
+            super().__init__(config, "ALIBI", quant_config)
 
 
 EntryClass = [BaichuanForCausalLM]
@@ -23,7 +23,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 LoraConfig = None
@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
         )
 
         # Self attention.
-        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+        self.self_attention = GLMAttention(config, layer_id, quant_config)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):
 
         # Transformer layers.
         self.layers = nn.ModuleList(
-            [
-                GLMBlock(config, i, cache_config, quant_config)
-                for i in range(self.num_layers)
-            ]
+            [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
         )
 
         if self.post_layer_norm:
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, cache_config, quant_config)
+        self.encoder = GLMTransformer(config, quant_config)
 
         self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoraConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMM(config, cache_config, quant_config)
+        self.transformer = ChatGLMM(config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)
......
@@ -49,7 +49,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import set_weight_attrs
@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
......
@@ -25,7 +25,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import set_weight_attrs
@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
......