Unverified commit 85e1a6f3, authored by Yineng Zhang, committed by GitHub

Update model_loader deps and qqq quantization deps (#2220) (#2318)


Co-authored-by: HandH1998 <1335248067@qq.com>
parent 33deca81
@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
model_runner = ModelRunner(
model_config=model_config,
...
import logging
from typing import Optional
import torch
logger = logging.getLogger(__name__)
class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "cuda") -> None:
if device in ["cuda", "xpu", "hpu"]:
self.device_type = device
else:
raise RuntimeError(f"Not supported device type: {device}")
self.device = torch.device(self.device_type)
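As a quick orientation for the new config module, here is a minimal usage sketch (not part of this commit) of how DeviceConfig behaves; the values are illustrative only.
# Minimal sketch, assuming the DeviceConfig class above is importable as shown.
from sglang.srt.configs.device_config import DeviceConfig

device_config = DeviceConfig("cuda")      # "xpu" and "hpu" are also accepted
print(device_config.device)               # cuda
# DeviceConfig("cpu") would raise RuntimeError("Not supported device type: cpu")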
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class LoadFormat(str, enum.Enum):
AUTO = "auto"
PT = "pt"
SAFETENSORS = "safetensors"
NPCACHE = "npcache"
DUMMY = "dummy"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights; defaults to the
default cache directory of Hugging Face.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Defaults to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
self.ignore_patterns,
)
else:
self.ignore_patterns = ["original/**/*"]
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f
for f in LoadFormat.__members__
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}"
)
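A short sketch (not from this commit) of what LoadConfig.__post_init__ does with typical inputs: a string load_format is normalized into the LoadFormat enum, a JSON string for model_loader_extra_config is parsed into a dict, and the default ignore pattern is filled in when none is given.
# Minimal sketch, assuming the LoadConfig/LoadFormat definitions above.
from sglang.srt.configs.load_config import LoadConfig, LoadFormat

load_config = LoadConfig(load_format="safetensors", model_loader_extra_config="{}")
assert load_config.load_format is LoadFormat.SAFETENSORS     # normalized to the enum
assert load_config.model_loader_extra_config == {}           # JSON string parsed into a dict
assert load_config.ignore_patterns == ["original/**/*"]      # default ignore pattern applied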
@@ -15,12 +15,14 @@
import json
import logging
from enum import IntEnum, auto
from typing import List, Optional
from typing import List, Optional, Union
import torch
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.utils import get_bool_env_var
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip
logger = logging.getLogger(__name__)
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
class ModelConfig:
def __init__(
self,
path: str,
model_path: str,
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
# Parse args
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
path,
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self._verify_quantization()
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ class ModelConfig:
# parallel size so each GPU has at least one KV head.
return max(1, total_num_kv_heads // tensor_parallel_size)
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = [
"awq",
"gptq",
"fp8",
"compressed_tensors",
"compressed-tensors",
"fbgemm_fp8",
]
optimized_quantization_methods = [
"fp8",
"marlin",
"modelopt",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"fbgemm_fp8",
"compressed_tensors",
"compressed-tensors",
"experts_int8",
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config()
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
# Detect which checkpoint it is
for _, method in QUANTIZATION_METHODS.items():
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization
)
if quantization_override:
quant_method = quantization_override
self.quantization = quantization_override
break
# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
raise ValueError(
"Quantization method specified in the model config "
f"({quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization})."
)
if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}."
)
if is_hip() and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm."
)
if self.quantization not in optimized_quantization_methods:
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.",
self.quantization,
)
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
# `torch.float16` is used by default for image features in `python/sglang/srt/models/llava.py`.
setattr(config, "torch_dtype", torch.float16)
return config
if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
return config
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
return torch_dtype
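A small sketch (not from this commit) of the dtype resolution rules above: with dtype="auto" a float32 checkpoint is downcast to float16 (bfloat16 for Gemma 2), while an explicit string is looked up in _STR_DTYPE_TO_TORCH_DTYPE; only casts between float16 and bfloat16 trigger a warning.
# Minimal sketch, assuming _get_and_verify_dtype is importable from the module above.
import torch
from transformers import PretrainedConfig
from sglang.srt.configs.model_config import _get_and_verify_dtype

config = PretrainedConfig(torch_dtype=torch.float32)                 # a float32 checkpoint
assert _get_and_verify_dtype(config, "auto") == torch.float16        # auto downcasts fp32 -> fp16
assert _get_and_verify_dtype(config, "bfloat16") == torch.bfloat16   # explicit override wins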
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architecture
...
@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# NOTE: the following section from original transformers config
# for Qwen2-VL is commented out to address rope config loading issue
#
# if self.rope_scaling is not None and "type" in self.rope_scaling:
#     if self.rope_scaling["type"] == "mrope":
#         self.rope_scaling["type"] = "default"
#     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
#     rope_config_validation(self)
# NOTE(HandH1998): This is necessary for configuring the `rope_type` of qwen2vl models after removing dependencies on vllm.
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
@@ -75,6 +75,8 @@ def get_config(
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model)
if model_override_args:
config.update(model_override_args)
...
@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"Fp8LinearMethod",
"MarlinLinearMethod",
"GPTQLinearMethod",
"QQQLinearMethod",
]
...
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.loader import DefaultModelLoader
class BaseLayerWithLoRA(nn.Module):
...
@@ -147,9 +147,12 @@ class Scheduler:
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation
...
@@ -109,9 +109,12 @@ class TokenizerManager:
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation
...
@@ -52,9 +52,12 @@ class TpModelWorker:
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
...
@@ -14,22 +14,12 @@
"""ModelRunner runs the forward passes of the models."""
import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from tokenize import tabsize
from typing import Any, Optional, Type, Union
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
get_tp_group,
init_distributed_environment,
@@ -37,9 +27,9 @@ from vllm.distributed import (
set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
@@ -228,49 +217,6 @@ class ModelRunner:
return min_per_gpu_memory
def setup_model(self):
try:
from vllm.config import VllmConfig
vllm_config = VllmConfig()
vllm_config.model_config = self.vllm_model_config
vllm_config.load_config = self.load_config
vllm_config.device_config = DeviceConfig(self.device)
vllm_config.quant_config = VllmConfig._get_quantization_config(
vllm_config.model_config, vllm_config.load_config
)
return get_model(vllm_config=vllm_config)
except ImportError:
pass
return get_model(
model_config=self.vllm_model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
parallel_config=None,
scheduler_config=None,
lora_config=None,
cache_config=None,
)
def get_model_config_params(self):
sig = inspect.signature(VllmModelConfig.__init__)
params = {
"model": self.server_args.model_path,
"quantization": self.server_args.quantization,
"tokenizer": None,
"tokenizer_mode": None,
"trust_remote_code": self.server_args.trust_remote_code,
"dtype": self.server_args.dtype,
"seed": self.server_args.random_seed,
"skip_tokenizer_init": True,
}
if "task" in sig.parameters:
params["task"] = ""
return params
def load_model(self):
logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -284,6 +230,7 @@ class ModelRunner:
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
self.model_config.dtype = torch.float16
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
@@ -292,23 +239,21 @@ class ModelRunner:
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
)
monkey_patch_vllm_model_config()
if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update(
self.model_config.model_override_args
)
self.model = self.setup_model()
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.dtype = self.vllm_model_config.dtype
self.dtype = self.model_config.dtype
logger.info(
f"Load weight end. "
@@ -319,12 +264,12 @@ class ModelRunner:
def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update engine weights online from disk."""
from vllm.model_executor.model_loader.loader import (
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
get_model_loader,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.utils import set_default_torch_dtype
logger.info(
f"Update engine weights online from disk begin. "
@@ -332,15 +277,7 @@ class ModelRunner:
)
target_device = torch.device(self.device)
self.model_config.model_path = model_path
try:
model_config_params = self.get_model_config_params()
model_config_params["model"] = model_path
vllm_model_config = VllmModelConfig(**model_config_params)
except Exception as e:
message = f"Failed to load model config: {e}."
return False, message
load_config = LoadConfig(load_format=load_format)
# Only support vllm DefaultModelLoader for now
@@ -352,7 +289,7 @@ class ModelRunner:
def get_weight_iter(config):
iter = loader._get_weights_iterator(
DefaultModelLoader.Source(
config.model,
config.model_path,
revision=config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
@@ -370,9 +307,9 @@ class ModelRunner:
quant_method.process_weights_after_loading(module)
return model
with set_default_torch_dtype(vllm_model_config.dtype):
with set_default_torch_dtype(self.model_config.dtype):
try:
iter = get_weight_iter(vllm_model_config)
iter = get_weight_iter(self.model_config)
except Exception as e:
message = f"Failed to get weights iterator: {e}."
return False, message
@@ -384,16 +321,14 @@ class ModelRunner:
)
del iter
gc.collect()
iter = get_weight_iter(self.vllm_model_config)
iter = get_weight_iter(self.model_config)
self.model = model_load_weights(self.model, iter)
return False, message
self.model = model
self.server_args.model_path = model_path
self.server_args.load_format = load_format
self.vllm_model_config = vllm_model_config
self.load_config = load_config
self.model_config.path = model_path
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
@@ -794,55 +729,3 @@ class ModelRunner:
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}. {e}")
if crash_on_warnings():
raise ValueError(f"Ignore import error when loading {name}. {e}")
continue
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry:
assert (
tmp.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {tmp.__name__}"
model_arch_name_to_cls[tmp.__name__] = tmp
else:
assert (
entry.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {entry.__name__}"
model_arch_name_to_cls[entry.__name__] = entry
return model_arch_name_to_cls
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
model_arch_name_to_cls = import_model_classes()
if model_arch not in model_arch_name_to_cls:
raise ValueError(
f"Unsupported architectures: {model_arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_arch_name_to_cls[model_arch]
# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
from torch import nn
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
from sglang.srt.model_loader.utils import (
get_architecture_class_name,
get_model_architecture,
)
def get_model(
*,
model_config: ModelConfig,
load_config: LoadConfig,
device_config: DeviceConfig,
) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(
model_config=model_config,
device_config=device_config,
)
__all__ = [
"get_model",
"get_model_loader",
"BaseModelLoader",
"get_architecture_class_name",
"get_model_architecture",
]
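For context, ModelRunner.load_model (see the diff above) now calls this get_model with the sglang ModelConfig, LoadConfig, and DeviceConfig instead of going through vLLM. Below is a hedged sketch (not from this commit) of the loader dispatch it relies on; the load_format value is illustrative.
# Minimal sketch, assuming the sglang.srt.model_loader package defined in this commit.
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.model_loader import get_model_loader

# "auto", "pt", and "safetensors" map to DefaultModelLoader; "dummy", "sharded_state",
# "gguf", and "bitsandbytes" select the specialized loaders defined in loader.py below.
loader = get_model_loader(LoadConfig(load_format="auto"))
print(type(loader).__name__)    # DefaultModelLoader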
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
# ruff: noqa: SIM117
import collections
import dataclasses
import fnmatch
import glob
import json
import logging
import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
import gguf
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.utils import (
get_model_architecture,
set_default_torch_dtype,
)
from sglang.srt.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names,
get_quant_config,
gguf_quant_weights_iterator,
initialize_dummy_weights,
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
)
from sglang.srt.utils import (
get_device_capability,
is_pin_memory_available,
set_weight_attrs,
)
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
if target_device.type == "cpu":
# If target is CPU, no need to move anything
yield module
return
original_device_states: Dict[str, torch.device] = {}
# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
# Parameters already on target device are not touched
try:
yield module
finally:
# Restore parameters to their original devices, ignoring new parameters
pin_memory = is_pin_memory_available()
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
if original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
p.data = cpu_data
else:
p.data = p.data.to(original_device)
# New parameters or parameters already on target device are untouched
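A brief usage sketch (not from this commit) of device_loading_context: CPU-resident parameters are moved to the target device for the duration of the block, for example while a quant method repacks weights, and are copied back to pinned CPU memory afterwards.
# Minimal sketch, assuming device_loading_context as defined above.
import torch
from torch import nn
from sglang.srt.model_loader.loader import device_loading_context

linear = nn.Linear(4, 4)                                              # parameters start on CPU
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with device_loading_context(linear, target):
    pass  # inside the block, CPU parameters live on `target` (no-op when target is CPU)

assert next(linear.parameters()).device.type == "cpu"                 # restored afterwards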
logger = logging.getLogger(__name__)
def _get_quantization_config(
model_config: ModelConfig, load_config: LoadConfig
) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
major, minor = get_device_capability()
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
return quant_config
return None
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)
quant_config = _get_quantization_config(model_config, load_config)
return model_class(
config=model_config.hf_config,
quant_config=quant_config,
)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def download_model(self, model_config: ModelConfig) -> None:
"""Download a model so that it can be immediately loaded."""
raise NotImplementedError
@abstractmethod
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
"""Load a model with the given configurations."""
raise NotImplementedError
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
revision: Optional[str]
"""The optional model revision."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def _maybe_download_from_modelscope(
self, model: str, revision: Optional[str]
) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if "SGLANG_USE_MODELSCOPE" in os.environ:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
if not os.path.exists(model):
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
ignore_file_pattern=self.load_config.ignore_patterns,
)
else:
model_path = model
return model_path
return None
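Note that this path is gated on the presence of the SGLANG_USE_MODELSCOPE environment variable (the docstring above still mentions vLLM's flag), and only its presence is checked, not its value. A hedged sketch of opting in:
# Minimal sketch; the loader only checks that the variable exists in os.environ.
import os

os.environ["SGLANG_USE_MODELSCOPE"] = "1"
# Subsequent DefaultModelLoader._prepare_weights calls will then resolve hub model IDs
# through modelscope.hub.snapshot_download instead of the Hugging Face Hub.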
def _prepare_weights(
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = (
self._maybe_download_from_modelscope(model_name_or_path, revision)
or model_name_or_path
)
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS:
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL:
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == LoadFormat.PT:
allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE:
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, source: "Source"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt
)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
source.model_or_path,
self.load_config.download_dir,
hf_folder,
hf_weights_files,
)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weights_iterator = pt_weights_iterator(hf_weights_files)
# Apply the prefix.
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
def _get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model_path,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
)
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(
Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ())
)
for source in secondary_weights:
yield from self._get_weights_iterator(source)
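For reference, a small sketch (values illustrative, not from this commit) of the Source record that drives the weight iterators above; update_weights_from_disk in model_runner.py builds one of these per weight source.
# Minimal sketch, assuming DefaultModelLoader as defined above; field values are illustrative.
from sglang.srt.model_loader.loader import DefaultModelLoader

source = DefaultModelLoader.Source(
    model_or_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model ID
    revision=None,
    prefix="",                  # prepended to every weight name yielded by the iterator
    fall_back_to_pt=True,       # allow *.pt files when no safetensors are found
)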
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(
model_config.model_path, model_config.revision, fall_back_to_pt=True
)
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(
model_config,
self.load_config,
)
model.load_weights(self._get_all_weights(model_config, model))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model.eval()
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(
model_config,
self.load_config,
)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
return model.eval()
class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_state.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = (
{}
if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy()
)
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}"
)
@staticmethod
def _filter_subtensors(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
collections.defaultdict(list)
)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))
def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
result: Dict[str, torch.Tensor] = {}
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break # t2 covers strictly more memory than t.
if k2 < k:
# Same tensors, keep the one with the smaller key.
break
else:
result[k] = t
return result
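A compact sketch (not from this commit) of what _filter_subtensors keeps: when one tensor's storage is fully covered by another's, for example tied embedding and lm_head weights or a sliced view, only the covering tensor is saved.
# Minimal sketch, assuming ShardedStateLoader as defined above.
import torch
from sglang.srt.model_loader.loader import ShardedStateLoader

full = torch.arange(8, dtype=torch.float32)
state_dict = {"full": full, "view": full[:4]}          # "view" aliases part of "full"

kept = ShardedStateLoader._filter_subtensors(state_dict)
print(sorted(kept))                                    # ['full'] -- the aliased view is dropped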
def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]):
if os.path.isdir(model_name_or_path):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
return download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model_path, model_config.revision)
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(
model_config.model_path, model_config.revision
)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
self.pattern.format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!"
)
state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into "
"parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval()
@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from safetensors.torch import save_file
from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
part_idx = 0
total_size = 0
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
state_dict_part: Dict[str, torch.Tensor] = {}
for key, tensor in state_dict.items():
param_size = tensor.nelement() * tensor.element_size()
if max_size is not None and total_size + param_size > max_size:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)
part_idx += 1
total_size = 0
state_dict_part = {}
state_dict_part[key] = tensor
total_size += param_size
if len(state_dict_part) > 0:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
possible_config_file_names = ["adapter_config.json"]
default_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
".fc1.",
".fc2.",
".dense.",
".query_key_value.",
".qkv_proj.",
".dense_h_to_4h.",
".dense_4h_to_h.",
".out_proj.",
]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# we don't need to quantize the whole model, only the target modules
# that are specified in the adapter config file. If the adapter config
# file is not provided, we will quantize the default modules.
if (
not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
):
self.target_modules = []
return
qlora_adapter = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"
]
config_file_path = self._get_config_file(qlora_adapter)
with open(config_file_path, "r") as f:
config = json.load(f)
self.target_modules = config["target_modules"]
def _get_config_file(self, qlora_adapter: str) -> str:
is_local = os.path.isdir(qlora_adapter)
config_file_path = None
if is_local:
for file in self.possible_config_file_names:
config_file_path = os.path.join(qlora_adapter, file)
if os.path.exists(config_file_path):
break
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
for file in self.possible_config_file_names:
if file in repo_files:
config_file_path = hf_hub_download(
repo_id=qlora_adapter, filename=file
)
break
if not config_file_path:
raise ValueError(f"Cannot find adapter config file in {qlora_adapter}")
return config_file_path
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None,
) -> Tuple[List[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
if weight_files:
return weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return glob.glob(os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(
self, model_name_or_path: str, revision: Optional[str]
) -> Tuple[List[str], bool]:
"""Prepare weight files for the model."""
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision
)
if matched_pattern != "*.safetensors":
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_weights_files, matched_pattern == "*.safetensors"
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
else:
return pt_weights_iterator(hf_weights_files)
def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: Optional[str],
pre_quant: bool,
load_8bit: bool,
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.44.0":
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.44.0."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.44.0 via "
"`pip install bitsandbytes>=0.44.0` to use "
"bitsandbytes quantizer."
) from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision
)
quant_state_dict: Dict[str, Any] = {}
if pre_quant:
if load_8bit:
return (
self._quantized_8bit_generator(
hf_weights_files, use_safetensors, quant_state_dict
),
quant_state_dict,
)
else:
return (
self._quantized_4bit_generator(
hf_weights_files, use_safetensors, quant_state_dict
),
quant_state_dict,
)
return (
self._unquantized_generator(
hf_weights_files, use_safetensors, quant_state_dict
),
quant_state_dict,
)
def _quantized_8bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors
):
if not weight_name.lower().endswith(".scb"):
continue
weight_key = weight_name.lower().replace(".scb", ".qweight")
quant_state_dict[weight_key] = weight_tensor
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors
):
if not weight_name.endswith((".weight", ".bias")):
continue
qweight_name = weight_name.replace(".weight", ".qweight")
if qweight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor
else:
yield weight_name, weight_tensor
def _quantized_4bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import QuantState
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith((".weight", ".bias")):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in weight_name:
temp_state_dict[weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str, temp_state_dict: Dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(quant_state, device="cuda")
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors
):
if not weight_name.endswith((".weight", ".bias")):
continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
):
quant_state = _parse_quant_state(weight_name, temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight", ".qweight"), weight_tensor
else:
yield weight_name, weight_tensor
def _unquantized_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import quantize_4bit
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors
):
if any(
target_module in weight_name for target_module in self.target_modules
) and weight_name.endswith(".weight"):
weight_name = weight_name.replace(".weight", ".qweight")
if any(
module in weight_name
for module in self.column_parallel_weights_modules
):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[..., start_index:end_index]
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index, ...]
# bitsandbytes requires data in GPU
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.cuda()
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight, compress_statistics=True, quant_type="nf4"
)
quant_state_dict[weight_name] = quant_state
else:
processed_weight = weight_tensor
yield weight_name, processed_weight
def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}."
)
if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet."
)
if len(self.target_modules) == 0:
if hasattr(model, "default_bitsandbytes_target_modules"):
self.target_modules = model.default_bitsandbytes_target_modules
else:
self.target_modules = self.default_target_modules
if hasattr(model, "column_parallel_weights_modules"):
self.column_parallel_weights_modules = model.column_parallel_weights_modules
else:
self.column_parallel_weights_modules = []
self.model_type = type(model).__name__
        logger.info(
            "Loading weights with BitsAndBytes quantization. May take a while ..."
        )
quant_config = getattr(model_config.hf_config, "quantization_config", None)
pre_quant = False
if quant_config is not None:
quant_method = quant_config.get("quant_method")
if quant_method == "bitsandbytes":
pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
"quantization"
)
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported. "
                "Please try with PP."
            )
load_8bit = False
if pre_quant:
load_8bit = quant_config.get("load_in_8bit", False)
qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
model_config.model_path, model_config.revision, pre_quant, load_8bit
)
model.load_weights(qweight_iterator)
torch.cuda.empty_cache()
param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
non_stacked_param_name = quant_param_name
shard_index = 0
for shard_name, (
weight_name,
index,
) in model.bitsandbytes_stacked_params_mapping.items():
if shard_name in quant_param_name:
shard_index = index
quant_param_name = quant_param_name.replace(shard_name, weight_name)
break
if quant_param_name not in param_dict:
raise ValueError(
f"Parameter {quant_param_name} not found in the model."
)
if quant_param_name not in stacked_quant_state_dict:
stacked_quant_state_dict[quant_param_name] = {}
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
non_stacked_param_name
]
# save quant_states and offsets as the attributes of the parameters
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
set_weight_attrs(param, {"bnb_quant_state": quant_states})
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
raise ValueError(f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
if load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)}
)
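    # Sketch of the bnb_shard_offsets computation above: if a stacked parameter
    # has three NF4 shards whose quant_state shapes flatten to 1024, 1024 and
    # 2048 packed elements (after dividing by pack_factor), then
    #   num_elements = [1024, 1024, 2048]
    #   offsets      = [0, 1024, 2048, 4096]
    # so shard i occupies param[offsets[i]:offsets[i + 1]] in the packed qweight.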
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model_path, model_config.revision)
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(
model_config,
self.load_config,
)
self._load_weights(model_config, model)
return model.eval()
class GGUFModelLoader(BaseModelLoader):
"""
Model loader that can load GGUF files. This is useful for loading models
that are quantized with GGUF and saved in the GGUF format. This loader
supports loading both full models and sharded models.
"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def _prepare_weights(self, model_name_or_path: str):
if os.path.isfile(model_name_or_path):
return model_name_or_path
else:
raise ValueError(f"{model_name_or_path} is not a file.")
def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
`blk.N.BB.weight` and `blk.N.BB.bias`
where N signifies the block number of a layer, and BB signifies the
attention/mlp layer components.
See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config = model_config.hf_config
model_type = config.model_type
        # Hack: GGUF uses a different model_type name than transformers for some models.
if model_type == "cohere":
model_type = "command-r"
arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
arch = key
break
if arch is None:
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
num_layers = config.num_hidden_layers
name_map = gguf.get_tensor_name_map(arch, num_layers)
with torch.device("meta"):
dummy_model = AutoModelForCausalLM.from_config(config)
state_dict = dummy_model.state_dict()
gguf_to_hf_name_map = {}
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
gguf_name = name_map.get_name(name)
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
return gguf_to_hf_name_map
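    # Example of the resulting mapping for a llama-style checkpoint (illustrative;
    # actual entries depend on the architecture found in gguf.MODEL_ARCH_NAMES):
    #   "blk.0.attn_q.weight"   -> "model.layers.0.self_attn.q_proj.weight"
    #   "blk.0.ffn_gate.weight" -> "model.layers.0.mlp.gate_proj.weight"
    #   "token_embd.weight"     -> "model.embed_tokens.weight"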
def _get_weights_iterator(
self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model_path)
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
local_model_path = self._prepare_weights(model_config.model_path)
gguf_weights_map = self._get_gguf_weights_map(model_config)
        # We can only know whether the word embeddings are tied after mapping the weights.
if "lm_head.weight" in get_gguf_extra_tensor_names(
local_model_path, gguf_weights_map
):
model_config.hf_config.update({"tie_word_embeddings": True})
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map)
)
return model
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.DUMMY:
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)
if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
return DefaultModelLoader(load_config)
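# Usage sketch (assumes an already-constructed ModelConfig for a local GGUF file;
# not part of the loader itself):
#
#   load_config = LoadConfig(load_format=LoadFormat.GGUF)
#   loader = get_model_loader(load_config)  # -> GGUFModelLoader
#   model = loader.load_model(
#       model_config=model_config, device_config=DeviceConfig("cuda")
#   )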
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/utils.py
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
import torch
from torch import nn
from sglang.srt.configs.model_config import ModelConfig
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
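# Minimal usage example for the context manager above: parameters created inside
# the block pick up the temporary default dtype.
#
#   with set_default_torch_dtype(torch.bfloat16):
#       layer = torch.nn.Linear(16, 16)  # layer.weight.dtype == torch.bfloat16
#   # Outside the block the previous default (normally torch.float32) is restored.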
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
from sglang.srt.models.registry import ModelRegistry
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"]
if (
model_config.quantization is not None
and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures
):
architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures)
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import fnmatch
import glob
import hashlib
import json
import logging
import os
import tempfile
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.utils import print_warning_once
logger = logging.getLogger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()
def enable_hf_transfer():
"""automatically activates hf_transfer"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
lock_dir = cache_dir or temp_dir
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
model_name = model_name_or_path.replace("/", "-")
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
# add hash to avoid conflict with old users' lock files
lock_file_name = hash_name + model_name + ".lock"
# mode 0o666 is required for the filelock to be shared across users
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
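# Usage sketch: serialize downloads of the same repo across processes (the repo
# id below is only illustrative):
#
#   with get_lock("org/some-model", cache_dir=None):
#       ...  # download or convert weights while holding the file lock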
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
def convert_bin_to_safetensor_file(
pt_filename: str,
sf_filename: str,
) -> None:
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
shared = _shared_pointers(loaded)
for shared_weights in shared:
for name in shared_weights[1:]:
loaded.pop(name)
# For tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}
dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
save_file(loaded, sf_filename, metadata={"format": "pt"})
# check file size
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
"""
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(
model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
    # GGUF doesn't have a config file.
if model_config.quantization == "gguf":
return quant_cls.from_config({})
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    # Some vision models may keep quantization_config in their text_config.
hf_text_config = getattr(model_config.hf_config, "text_config", None)
if hf_quant_config is None and hf_text_config is not None:
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
if hf_quant_config is None:
        # compressed-tensors uses a "compression_config" key.
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
if model_config.quantization == "bitsandbytes":
if (
not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
):
return quant_cls.from_config({"adapter_name_or_path": ""})
model_name_or_path = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"
]
else:
model_name_or_path = model_config.model_path
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_name_or_path
possible_config_filenames = quant_cls.get_config_filenames()
# If the quantization config is not found, use the default config.
if not possible_config_filenames:
return quant_cls()
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_config_files = [
f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
]
if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
if len(quant_config_files) > 1:
raise ValueError(
f"Found multiple config files for {model_config.quantization}: "
f"{quant_config_files}"
)
quant_config_file = quant_config_files[0]
with open(quant_config_file) as f:
config = json.load(f)
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
return quant_cls.from_config(config)
else:
            raise ValueError(
                "Unsupported quantization config"
                f" found for {model_config.quantization} in {quant_config_file}."
            )
return quant_cls.from_config(config)
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None,
ignore_patterns: Optional[Union[str, List[str]]] = None,
) -> str:
"""Download model weights from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
allow_patterns (List[str]): The allowed patterns for the
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns
will be ignored.
Returns:
str: The path to the downloaded model weights.
"""
if not huggingface_hub.constants.HF_HUB_OFFLINE:
        # Before we download, we look at what is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
return hf_folder
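# Usage sketch (hypothetical repo id). Because the function keeps only the first
# allow-pattern that matches anything on the Hub, listing safetensors first makes
# .bin files a fallback:
#
#   folder = download_weights_from_hf(
#       "org/some-model",
#       cache_dir=None,
#       allow_patterns=["*.safetensors", "*.bin"],
#       ignore_patterns=["original/**/*"],
#   )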
def download_safetensors_index_file_from_hf(
model_name_or_path: str,
index_file: str,
cache_dir: Optional[str],
revision: Optional[str] = None,
) -> None:
"""Download hf safetensors index file from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
revision (Optional[str]): The revision of the model.
"""
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
try:
# Download the safetensors index file.
hf_hub_download(
repo_id=model_name_or_path,
filename=index_file,
cache_dir=cache_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
# If file not found on remote or locally, we should not fail since
# only some models will have index_file.
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", index_file)
except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", index_file)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(
hf_weights_files: List[str], hf_folder: str, index_file: str
) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name):
return hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
    # Filter out any weight files that are not listed in the index file.
hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
return hf_weights_files
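# For reference, model.safetensors.index.json has roughly this shape (illustrative):
#
#   {
#     "metadata": {"total_size": 13476839424},
#     "weight_map": {
#       "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
#       "lm_head.weight": "model-00002-of-00002.safetensors"
#     }
#   }
#
# Only files referenced in weight_map survive the filter above, which drops e.g.
# a consolidated safetensors file that duplicates the sharded weights.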
def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
"""
Exclude files that are not needed for inference.
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
"""
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
]
return hf_weights_files
# Explicitly use a pure-text bar format with a newline at the end.
# This makes the progress-bar animation invisible, but it avoids interfering with
# ray or multiprocessing, which wrap each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
def np_cache_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str],
hf_folder: str,
hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model np files.
Will dump the model weights to numpy files if they are not already dumped.
"""
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, "weight_names.json")
# Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names: List[str] = []
for bin_file in tqdm(
hf_weights_files,
desc="Loading np_cache checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)
with open(weight_names_file) as f:
weight_names = json.load(f)
for name in weight_names:
param_path = os.path.join(np_folder, name)
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
def safetensors_weights_iterator(
hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
def pt_weights_iterator(
hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
yield from state.items()
del state
torch.cuda.empty_cache()
def get_gguf_extra_tensor_names(
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]:
reader = gguf.GGUFReader(gguf_file)
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
extra_keys = expected_gguf_keys - exact_gguf_keys
return [gguf_to_hf_name_map[key] for key in extra_keys]
def gguf_quant_weights_iterator(
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""
Iterate over the quant weights in the model gguf files and convert
them to torch tensors
"""
reader = gguf.GGUFReader(gguf_file)
for tensor in reader.tensors:
if tensor.name in gguf_to_hf_name_map:
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
if weight_type.name != "F32":
weight_type_name = name.replace("weight", "qweight_type")
weight_type = torch.tensor(weight_type)
yield weight_type_name, weight_type
for tensor in reader.tensors:
if tensor.name in gguf_to_hf_name_map:
weight = tensor.data
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
if weight_type.name != "F32":
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
yield name, param
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor
PySafeSlice object supports indexing, which is done before loading the
actual tensor and can reduce the amount of memory being read into the
memory. However, it does not support more advanced functionalities
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
tensor with these more complicated operators, we need to convert to
tensor first.
"""
if not isinstance(x, torch.Tensor):
x = x[:]
return x
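# Sketch of where the lazy slices come from (safetensors API; illustrative names):
#
#   with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
#       sl = f.get_slice("model.embed_tokens.weight")  # PySafeSlice, nothing loaded yet
#       part = sl[:1024]                               # reads only the first 1024 rows
#       full = convert_pyslice_to_tensor(sl)           # materializes the whole tensor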
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
try:
if param.numel() == 1 and loaded_weight.numel() == 1:
# Sometimes scalar values aren't considered tensors with shapes
# so if both param and loaded_weight are a scalar,
# "broadcast" instead of copy
param.data.fill_(loaded_weight.item())
else:
assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
f"into parameter ({param.size()})"
)
param.data.copy_(loaded_weight)
except Exception:
# NOTE: This exception is added for the purpose of setting breakpoint to
# debug weight loading issues.
raise
def row_parallel_weight_loader(
param: torch.Tensor, loaded_weight: torch.Tensor
) -> None:
"""Load weights that are row-parallelized."""
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
if shard_dim is not None:
shard_size = param.data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
"""Create a weight loader that shards the weights along the given axis"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
shard_size = param.data.shape[shard_axis]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
return loader
def composed_weight_loader(
loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor]
) -> LoaderFunction:
"""Create a weight loader that post-processes the weights after loading"""
def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
loader(param, loaded_weight)
param.data.copy_(fn(param))
return
return composed_loader
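# Usage sketch: shard along dim 0, then post-process the loaded shard in place
# (the transform below is only an example):
#
#   loader = composed_weight_loader(
#       sharded_weight_loader(0), lambda x: -torch.exp(x.float())
#   )
#   loader(param, loaded_weight)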
def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,
high: float = 1e-3,
seed: int = 1234,
) -> None:
"""Initialize model weights with random values.
The model weights must be randomly initialized for accurate performance
measurements. Additionally, the model weights should not cause NaNs in the
forward pass. We empirically found that initializing the weights with
values between -1e-3 and 1e-3 works well for most models.
We use per-parameter random seed, so that dummy weights are consistent,
even if the model is partitioned across multiple devices. When the seed
    is fixed, the random values generated by this function only depend on
the parameter's number of elements and its data type.
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
generator = torch.Generator(device=param.data.device)
generator.manual_seed(seed)
if torch.finfo(param.data.dtype).bits < 16:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype = param.data.dtype
tmp_param = param.data.to(torch.float16)
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
param.data.copy_(tmp_param)
else:
param.uniform_(low, high, generator=generator)
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
"""Remap the name of FP8 k/v_scale parameters.
This function handles the remapping of FP8 k/v_scale parameter names.
It detects if the given name ends with a suffix and attempts to remap
it to the expected name format in the model. If the remapped name is not
found in the params_dict, a warning is printed and None is returned.
Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.
Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
if name.endswith(".kv_scale"):
print_warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale"
)
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
print_warning_once(
f"Found kv_scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). kv_scale is "
"not loaded."
)
return None
return remapped_name
possible_scale_names = [".k_scale", ".v_scale"]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
print_warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). {scale_name} is "
"not loaded."
)
return None
return remapped_name
# If there were no matches, return the untouched param name
return name
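# Example of the remapping performed above (illustrative parameter names):
#   "model.layers.0.self_attn.kv_scale" -> "model.layers.0.self_attn.attn.k_scale"
#   "model.layers.0.self_attn.k_scale"  -> "model.layers.0.self_attn.attn.k_scale"
#   "model.layers.0.self_attn.v_scale"  -> "model.layers.0.self_attn.attn.v_scale"
# Names that do not end in a known scale suffix are returned unchanged.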
...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import ( ...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
...@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", cache_config, quant_config) super().__init__(config, "ROPE", quant_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", cache_config, quant_config) super().__init__(config, "ALIBI", quant_config)
EntryClass = [BaichuanForCausalLM] EntryClass = [BaichuanForCausalLM]
...@@ -23,7 +23,6 @@ from torch import nn ...@@ -23,7 +23,6 @@ from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
...@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
LoraConfig = None LoraConfig = None
...@@ -50,7 +50,6 @@ class GLMAttention(nn.Module): ...@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
self, self,
config, config,
layer_id: int = 0, layer_id: int = 0,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -186,7 +185,6 @@ class GLMBlock(nn.Module): ...@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
self, self,
config, config,
layer_id: int, layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -203,7 +201,7 @@ class GLMBlock(nn.Module): ...@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
) )
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config) self.self_attention = GLMAttention(config, layer_id, quant_config)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module): ...@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module): ...@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers)
]
) )
if self.post_layer_norm: if self.post_layer_norm:
...@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module): ...@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module): ...@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config) self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
...@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoraConfig] = None,
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = ChatGLMM(config, cache_config, quant_config) self.transformer = ChatGLMM(config, quant_config)
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -49,7 +49,6 @@ from vllm.distributed import ( ...@@ -49,7 +49,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
...@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig ...@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
...@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module): ...@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
......
...@@ -25,7 +25,6 @@ from vllm.distributed import ( ...@@ -25,7 +25,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.fused_moe_triton import fused_moe from sglang.srt.layers.fused_moe_triton import fused_moe
...@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
...@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module): ...@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
self, self,
config: DbrxConfig, config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
......