Unverified commit 19d2135c authored by Lianmin Zheng, committed by GitHub

Use model loader from vllm (#459)

parent ced77c66
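In short: instead of each model pulling its own checkpoint through hf_model_weights_iterator(model_name_or_path, cache_dir, load_format, revision), the router's ModelRunner now builds a vllm ModelConfig and calls vllm's get_model(), and every sglang model exposes load_weights(weights) that consumes an iterator of (name, tensor) pairs. A minimal sketch of that contract (illustration only; TinyModel is not part of the commit):

from typing import Iterable, Tuple

import torch
import torch.nn as nn
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Same shape as the per-model implementations in this diff: walk the
        # (name, tensor) pairs and hand each one to its weight loader.
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name]  # real models also remap stacked/sharded names
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


# Usage: any iterator of (name, tensor) pairs works, e.g. another state_dict.
m = TinyModel()
m.load_weights(iter(TinyModel().state_dict().items()))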
""" """
Usage: python3 srt_example_yi_vl.py Usage: python3 srt_example_yi_vl.py
Requirements: transformers==4.38
""" """
import sglang as sgl import sglang as sgl
......
@@ -41,6 +41,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
+logging.getLogger("vllm.selector").setLevel(logging.WARN)

class ModelRpcServer:
...
import importlib
import importlib.resources
import inspect
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type

import numpy as np
import torch
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry

from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}

logger = logging.getLogger("model_runner")

@@ -31,35 +26,6 @@ logger = logging.getLogger("model_runner")
global_server_args_dict = {}
@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
return model_arch_name_to_cls
def get_model_cls_by_arch_name(model_arch_names):
model_arch_name_to_cls = import_model_classes()
model_class = None
for arch in model_arch_names:
if arch in model_arch_name_to_cls:
model_class = model_arch_name_to_cls[arch]
break
else:
raise ValueError(
f"Unsupported architectures: {arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_class
@dataclass
class InputMetadata:
    model_runner: "ModelRunner"

@@ -287,49 +253,32 @@ class ModelRunner:
        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
-        """See also vllm/model_executor/model_loader.py::get_model"""
-        # Select model class
-        architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = get_model_cls_by_arch_name(architectures)
        logger.info(f"Rank {self.tp_rank}: load weight begin.")
-        # Load weights
-        quant_config = None
-        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
-        if quant_cfg is not None:
-            quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = quant_cfg.get(
-                "checkpoint_format"
-            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                quant_method = "marlin"
-            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
-            if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {quant_method}")
-            quant_config = quant_config_class.from_config(quant_cfg)
-            logger.info(f"quant_config: {quant_config}")
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-            model.load_weights(
-                self.model_config.path,
-                cache_dir=None,
-                load_format=self.load_format,
-                revision=None,
-            )
-        self.model = model.eval()
+        device_config = DeviceConfig()
+        load_config = LoadConfig()
+        vllm_model_config = VllmModelConfig(
+            model=self.model_config.path,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.model_config.trust_remote_code,
+            dtype=torch.float16,
+            seed=42,
+            revision=self.model_config.revision,
+            skip_tokenizer_init=True,
+        )
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+        )
        logger.info(f"Rank {self.tp_rank}: load weight end.")

    def profile_max_num_token(self, total_gpu_memory):

@@ -455,3 +404,30 @@ class ModelRunner:
            return self.forward_prefill(batch)
        else:
            raise ValueError(f"Invaid forward mode: {forward_mode}")
@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
return model_arch_name_to_cls
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
model_arch_name_to_cls = import_model_classes()
if model_arch not in model_arch_name_to_cls:
raise ValueError(
f"Unsupported architectures: {model_arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_arch_name_to_cls[model_arch]
# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
\ No newline at end of file
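For context (not part of the diff): vllm's get_model() resolves the class named in hf_config.architectures through ModelRegistry.load_model_cls, so the monkey patch above is what routes instantiation to sglang's own model implementations. A rough sketch of the effect, assuming sglang and vllm are both importable:

from vllm.model_executor.models import ModelRegistry

import sglang.srt.managers.router.model_runner  # importing the module applies the patch above

# After the setattr() patch, the registry answers with the EntryClass defined
# under sglang.srt.models, not with vllm's built-in implementation.
model_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")
print(model_cls.__module__)  # expected: sglang.srt.models.llama2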
@@ -15,10 +15,9 @@ class ModelConfig:
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
-        self.hf_config = get_config(self.path, trust_remote_code, revision)
-        if model_overide_args is not None:
-            self.hf_config.update(model_overide_args)
+        self.model_overide_args = model_overide_args
+        self.hf_config = get_config(self.path, trust_remote_code, revision,
+                                    model_overide_args=model_overide_args)
        if context_length is not None:
            self.context_len = context_length

@@ -44,4 +43,4 @@ class ModelConfig:
        self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_config.hidden_size
        self.num_hidden_layers = self.hf_config.num_hidden_layers
        self.vocab_size = self.hf_config.vocab_size
\ No newline at end of file
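A small usage sketch of the new flow: model_overide_args is now handed to get_config() up front instead of being patched onto hf_config afterwards, and it is kept on the config so ModelRunner.load_model() can re-apply it to vllm's hf_config. Keyword names come from this hunk; the module path and example values are assumptions:

from sglang.srt.model_config import ModelConfig  # module path assumed

model_config = ModelConfig(
    path="meta-llama/Llama-2-7b-hf",  # example value
    trust_remote_code=False,
    revision=None,
    model_overide_args={"max_position_embeddings": 8192},  # example override
)
# hf_config is already built with the override applied; the old post-hoc
# hf_config.update(...) call is gone.
print(model_config.hf_config.max_position_embeddings)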
@@ -18,9 +18,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable

import torch
import torch.utils.checkpoint

@@ -41,11 +44,11 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

@torch.compile

@@ -324,13 +327,7 @@ class CohereForCausalLM(nn.Module):
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),

@@ -341,9 +338,7 @@ class CohereForCausalLM(nn.Module):
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
...
# Adapted from:
-# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
# coding=utf-8
-from typing import Optional
+from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

@@ -24,12 +24,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
)
from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.transformers_utils.configs.dbrx import DbrxConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.models.dbrx_config import DbrxConfig
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

class DbrxRouter(nn.Module):

@@ -377,13 +377,7 @@ class DbrxForCausalLM(nn.Module):
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        expert_params_mapping = [
            (
                "ws" if weight_name in ["w1", "v1"] else "w2s",

@@ -392,9 +386,7 @@ class DbrxForCausalLM(nn.Module):
            for weight_name in ["w1", "v1", "w2"]
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
...
# Adapted from:
# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/transformers_utils/configs/dbrx.py
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
"""Dbrx configuration."""
# FIXME: remove this once vllm releases a new version
from typing import Any, Optional
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DbrxAttentionConfig(PretrainedConfig):
"""Configuration class for Dbrx Attention.
[`DbrxAttention`] class. It is used to instantiate attention layers
according to the specified arguments, defining the layers architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
attn_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention layers.
clip_qkv (`float`, *optional*, defaults to None):
If not `None`, clip the queries, keys, and values in the attention layer to this value.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
rope_theta (float): The base frequency for rope.
"""
def __init__(
self,
attn_pdrop: float = 0,
clip_qkv: Optional[float] = None,
kv_n_heads: int = 1,
rope_theta: float = 10000.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.attn_pdrop = attn_pdrop
self.clip_qkv = clip_qkv
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: str, **kwargs: Any
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
)
if config_dict.get("model_type") == "dbrx":
config_dict = config_dict["attn_config"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class DbrxFFNConfig(PretrainedConfig):
"""Configuration class for Dbrx FFN.
[`DbrxFFN`] class. It is used to instantiate feedforward layers according to
the specified arguments, defining the layers architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
The dict should have a key 'name' with the value being the name of
the activation function along with any additional keyword arguments.
ffn_hidden_size (int, optional): The hidden size of the feedforward network.
moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
This should only be used for benchmarking purposes.
"""
def __init__(
self,
ffn_act_fn: Optional[dict] = None,
ffn_hidden_size: int = 3584,
moe_num_experts: int = 4,
moe_top_k: int = 1,
moe_jitter_eps: Optional[float] = None,
moe_loss_weight: float = 0.01,
moe_normalize_expert_weights: Optional[float] = 1,
uniform_expert_assignment: bool = False,
**kwargs: Any,
):
super().__init__()
if ffn_act_fn is None:
ffn_act_fn = {"name": "silu"}
self.ffn_act_fn = ffn_act_fn
self.ffn_hidden_size = ffn_hidden_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.moe_jitter_eps = moe_jitter_eps
self.moe_loss_weight = moe_loss_weight
self.moe_normalize_expert_weights = moe_normalize_expert_weights
self.uniform_expert_assignment = uniform_expert_assignment
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: str, **kwargs: Any
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
)
if config_dict.get("model_type") == "dbrx":
config_dict = config_dict["ffn_config"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class DbrxConfig(PretrainedConfig):
"""Configuration class for Dbrx.
[`DbrxModel`]. It is used to instantiate a Dbrx model according to the
specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
d_model (`int`, *optional*, defaults to 6144):
Dimensionality of the embeddings and hidden states.
n_heads (`int`, *optional*, defaults to 48):
Number of attention heads for each attention layer in the Transformer encoder.
n_layers (`int`, *optional*, defaults to 40):
Number of hidden layers in the Transformer encoder.
max_seq_len (`int`, *optional*, defaults to 32768):
The maximum sequence length of the model.
vocab_size (`int`, *optional*, defaults to 100352):
Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
the `inputs_ids` passed when calling [`DbrxModel`].
resid_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability applied to the attention output before combining with residual.
emb_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the embedding layer.
attn_config (`dict`, *optional*):
A dictionary used to configure the model's attention module.
ffn_config (`dict`, *optional*):
A dictionary used to configure the model's FFN module.
use_cache (`bool`, *optional*, defaults to `False`):
Whether or not the model should return the last key/values attentions (not used by all models).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
Example:
```python
>>> from transformers import DbrxConfig, DbrxModel
>>> # Initializing a Dbrx configuration
>>> configuration = DbrxConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = DbrxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "dbrx"
attribute_map = {
"num_attention_heads": "n_heads",
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
"max_position_embeddings": "max_seq_len",
}
def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
max_seq_len: int = 2048,
vocab_size: int = 32000,
resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0,
attn_config: Optional[DbrxAttentionConfig] = None,
ffn_config: Optional[DbrxFFNConfig] = None,
use_cache: bool = True,
initializer_range: float = 0.02,
output_router_logits: bool = False,
router_aux_loss_coef: float = 0.05,
**kwargs: Any,
):
if attn_config is None:
self.attn_config = DbrxAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = DbrxAttentionConfig(**attn_config)
else:
self.attn_config = attn_config
if ffn_config is None:
self.ffn_config = DbrxFFNConfig()
elif isinstance(ffn_config, dict):
self.ffn_config = DbrxFFNConfig(**ffn_config)
else:
self.ffn_config = ffn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.use_cache = use_cache
self.initializer_range = initializer_range
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Dbrx models."
)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
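The file above is deleted because the same configuration class now ships with vllm (see the new "from vllm.transformers_utils.configs.dbrx import DbrxConfig" import in dbrx.py), which is what the FIXME at its top was waiting for. A quick sanity check, assuming a vllm build that bundles it (0.4.x at the time of this commit):

from vllm.transformers_utils.configs.dbrx import DbrxConfig

cfg = DbrxConfig()  # defaults should match the deleted copy above
print(cfg.model_type, cfg.num_attention_heads, cfg.hidden_size)  # dbrx 16 2048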
# Adapted from:
-# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
"""Inference-only Gemma model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

import torch
from torch import nn

@@ -18,11 +18,11 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

class GemmaMLP(nn.Module):

@@ -285,13 +285,7 @@ class GemmaForCausalLM(nn.Module):
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),

@@ -302,9 +296,7 @@ class GemmaForCausalLM(nn.Module):
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
...
# Adapted from
-# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Iterable

import torch
from torch import nn

@@ -20,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

class LlamaMLP(nn.Module):

@@ -152,6 +152,10 @@ class LlamaDecoderLayer(nn.Module):
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,

@@ -270,13 +274,7 @@ class LlamaForCausalLM(nn.Module):
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),

@@ -286,9 +284,7 @@ class LlamaForCausalLM(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
...
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional from typing import List, Iterable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -8,6 +8,7 @@ from torch import nn ...@@ -8,6 +8,7 @@ from torch import nn
from transformers import CLIPVisionModel, LlavaConfig from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
...@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import ( ...@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaLlamaForCausalLM(nn.Module): class LlavaLlamaForCausalLM(nn.Module):
...@@ -233,13 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module): ...@@ -233,13 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
elif input_metadata.forward_mode == ForwardMode.DECODE: elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata) return self.language_model(input_ids, positions, input_metadata)
def load_weights( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']: # load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir # huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower vision_path = self.config.mm_vision_tower
...@@ -272,9 +266,8 @@ class LlavaLlamaForCausalLM(nn.Module): ...@@ -272,9 +266,8 @@ class LlavaLlamaForCausalLM(nn.Module):
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
} }
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( weights = list(weights)
model_name_or_path, cache_dir, load_format, revision for name, loaded_weight in weights:
):
# FIXME: why projector weights read two times? # FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name: if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items(): for weight_name, param_name in projector_weights.items():
...@@ -285,9 +278,7 @@ class LlavaLlamaForCausalLM(nn.Module): ...@@ -285,9 +278,7 @@ class LlavaLlamaForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load language model # load language model
self.language_model.load_weights( self.language_model.load_weights(weights)
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward() monkey_path_clip_vision_embed_forward()
......
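A note on the weights = list(weights) line above (explanation added here; it is not in the commit): the loader hands load_weights a one-shot iterator, and this method walks it once for the projector / vision-tower parameters and then forwards the same object to self.language_model.load_weights(weights), so it has to be materialized first. A minimal illustration:

pairs = iter([("a", 1), ("b", 2)])
assert [name for name, _ in pairs] == ["a", "b"]
assert list(pairs) == []  # already exhausted; list() up front avoids this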
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional from typing import List, Iterable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -8,6 +8,7 @@ from torch import nn ...@@ -8,6 +8,7 @@ from torch import nn
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
...@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import ( ...@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaMistralForCausalLM(nn.Module): class LlavaMistralForCausalLM(nn.Module):
...@@ -246,13 +246,7 @@ class LlavaMistralForCausalLM(nn.Module): ...@@ -246,13 +246,7 @@ class LlavaMistralForCausalLM(nn.Module):
elif input_metadata.forward_mode == ForwardMode.DECODE: elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata) return self.language_model(input_ids, positions, input_metadata)
def load_weights( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']: # load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir # huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower vision_path = self.config.mm_vision_tower
...@@ -285,9 +279,8 @@ class LlavaMistralForCausalLM(nn.Module): ...@@ -285,9 +279,8 @@ class LlavaMistralForCausalLM(nn.Module):
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
} }
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( weights = list(weights)
model_name_or_path, cache_dir, load_format, revision for name, loaded_weight in weights:
):
# FIXME: why projector weights read two times? # FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name: if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items(): for weight_name, param_name in projector_weights.items():
...@@ -298,9 +291,7 @@ class LlavaMistralForCausalLM(nn.Module): ...@@ -298,9 +291,7 @@ class LlavaMistralForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load language model # load language model
self.language_model.load_weights( self.language_model.load_weights(weights)
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward() monkey_path_clip_vision_embed_forward()
......
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional from typing import List, Iterable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -8,6 +8,7 @@ from torch import nn ...@@ -8,6 +8,7 @@ from torch import nn
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
...@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import ( ...@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaQwenForCausalLM(nn.Module): class LlavaQwenForCausalLM(nn.Module):
...@@ -246,13 +246,7 @@ class LlavaQwenForCausalLM(nn.Module): ...@@ -246,13 +246,7 @@ class LlavaQwenForCausalLM(nn.Module):
elif input_metadata.forward_mode == ForwardMode.DECODE: elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata) return self.language_model(input_ids, positions, input_metadata)
def load_weights( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']: # load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir # huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower vision_path = self.config.mm_vision_tower
...@@ -285,9 +279,8 @@ class LlavaQwenForCausalLM(nn.Module): ...@@ -285,9 +279,8 @@ class LlavaQwenForCausalLM(nn.Module):
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
} }
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( weights = list(weights)
model_name_or_path, cache_dir, load_format, revision for name, loaded_weight in weights:
):
# FIXME: why projector weights read two times? # FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name: if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items(): for weight_name, param_name in projector_weights.items():
...@@ -298,9 +291,7 @@ class LlavaQwenForCausalLM(nn.Module): ...@@ -298,9 +291,7 @@ class LlavaQwenForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load language model # load language model
self.language_model.load_weights( self.language_model.load_weights(weights)
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward() monkey_path_clip_vision_embed_forward()
......
"""Inference-only LLaVa video model compatible with HuggingFace weights.""" """Inference-only LLaVa video model compatible with HuggingFace weights."""
import os from typing import List, Iterable, Optional, Tuple
from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
...@@ -18,7 +18,6 @@ from sglang.srt.mm_utils import ( ...@@ -18,7 +18,6 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaVidForCausalLM(nn.Module): class LlavaVidForCausalLM(nn.Module):
...@@ -65,7 +64,6 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -65,7 +64,6 @@ class LlavaVidForCausalLM(nn.Module):
pad_ids = pad_value * ( pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value) (new_image_feature_len + len(pad_value)) // len(pad_value)
) )
# print(input_ids)
offset = input_ids.index(self.config.image_token_index) offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id # old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = ( new_input_ids = (
...@@ -200,13 +198,7 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -200,13 +198,7 @@ class LlavaVidForCausalLM(nn.Module):
elif input_metadata.forward_mode == ForwardMode.DECODE: elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata) return self.language_model(input_ids, positions, input_metadata)
def load_weights( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']: # load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir # huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower vision_path = self.config.mm_vision_tower
...@@ -244,9 +236,8 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -244,9 +236,8 @@ class LlavaVidForCausalLM(nn.Module):
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
} }
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( weights = list(weights)
model_name_or_path, cache_dir, load_format, revision for name, loaded_weight in weights:
):
# FIXME: why projector weights read two times? # FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name: if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items(): for weight_name, param_name in projector_weights.items():
...@@ -261,9 +252,7 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -261,9 +252,7 @@ class LlavaVidForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load language model # load language model
self.language_model.load_weights( self.language_model.load_weights(weights)
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward() monkey_path_clip_vision_embed_forward()
......
# Adapted from
-# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
"""Inference-only Mixtral model."""
-from typing import Optional
+from typing import Iterable, Optional, Tuple

import numpy as np
import torch

@@ -25,11 +25,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

class MixtralMLP(nn.Module):

@@ -107,7 +108,7 @@ class MixtralMoE(nn.Module):
            ]
        )
        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
+            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -333,13 +334,7 @@ class MixtralForCausalLM(nn.Module):
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),

@@ -348,13 +343,7 @@ class MixtralForCausalLM(nn.Module):
        ]
        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path,
-            cache_dir,
-            load_format,
-            revision,
-            fall_back_to_pt=False,
-        ):
+        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
...
-from typing import Any, Dict, Optional
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+from typing import Any, Dict, Optional, Iterable, Tuple

import torch
from torch import nn

@@ -17,11 +19,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

class QWenMLP(nn.Module):

@@ -245,22 +247,14 @@ class QWenLMHeadModel(nn.Module):
        )
        return next_tokens

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
...
# Adapted from llama2.py
# Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Iterable

import torch
from torch import nn

@@ -19,11 +19,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

Qwen2Config = None

@@ -271,13 +271,7 @@ class Qwen2ForCausalLM(nn.Module):
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),

@@ -287,9 +281,7 @@ class Qwen2ForCausalLM(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
...
-# This code is based on:
-# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable

import torch
from torch import nn

@@ -20,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

class StablelmMLP(nn.Module):

@@ -245,13 +245,7 @@ class StableLmForCausalLM(nn.Module):
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),

@@ -261,9 +255,7 @@ class StableLmForCausalLM(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
...
"""Inference-only Yi-VL model.""" """Inference-only Yi-VL model."""
import os from typing import Tuple, Iterable
from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llava import ( from sglang.srt.models.llava import (
LlavaLlamaForCausalLM, LlavaLlamaForCausalLM,
clip_vision_embed_forward,
monkey_path_clip_vision_embed_forward, monkey_path_clip_vision_embed_forward,
) )
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class YiVLForCausalLM(LlavaLlamaForCausalLM): class YiVLForCausalLM(LlavaLlamaForCausalLM):
def __init__(self, *args, **kwargs): def __init__(
self.config = kwargs["config"] self, config, quant_config = None,
super().__init__(self.config) ) -> None:
super().__init__(config, quant_config)
self.multi_modal_projector = YiVLMultiModalProjector(self.config) self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace( self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
"./", "" "./", ""
) # Everything after "./" ) # Everything after "./"
def load_weights( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B) # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self.vision_tower = CLIPVisionModel.from_pretrained( self.vision_tower = CLIPVisionModel.from_pretrained(
model_name_or_path, self.config._name_or_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
subfolder=self.vision_tower_subfolder, subfolder=self.vision_tower_subfolder,
).cuda() ).cuda()
...@@ -68,9 +61,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): ...@@ -68,9 +61,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
} }
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( weights = list(weights)
model_name_or_path, cache_dir, load_format, revision for name, loaded_weight in weights:
):
if "projector" in name or "vision_tower" in name: if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items(): for weight_name, param_name in projector_weights.items():
if weight_name in name: if weight_name in name:
...@@ -80,9 +72,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): ...@@ -80,9 +72,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load language model # load language model
self.language_model.load_weights( self.language_model.load_weights(weights)
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward() monkey_path_clip_vision_embed_forward()
...@@ -103,7 +93,7 @@ class YiVLMultiModalProjector(nn.Module): ...@@ -103,7 +93,7 @@ class YiVLMultiModalProjector(nn.Module):
def forward(self, image_features): def forward(self, image_features):
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)
hidden_state = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states) hidden_states = self.linear_2(hidden_states)
hidden_states = self.ln_2(hidden_states) hidden_states = self.ln_2(hidden_states)
......
# The vllm PR (https://github.com/vllm-project/vllm/pull/4097) broke the sglang code.
# To adapt to the latest vllm without modifying too much of sglang, we keep a copy of
# the previous vllm/model_executor/weight_utils.py, taken from
# https://github.com/vllm-project/vllm/blob/05434764cd99990035779cf9a4ed86623b528825/vllm/model_executor/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import fnmatch
import glob
import hashlib
import json
import os
from collections import defaultdict
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
import filelock
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
get_quantization_config,
)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = (
os.environ.get("TMPDIR")
or os.environ.get("TEMP")
or os.environ.get("TMP")
or "/tmp/"
)
def enable_hf_transfer():
"""automatically activates hf_transfer"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
enable_hf_transfer()
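# With hf_transfer enabled, the snapshot_download() calls below use Hugging Face's
# Rust-based downloader, which is typically much faster for multi-GB checkpoints.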
class Disabledtqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
lock_dir = cache_dir or temp_dir
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
model_name = model_name_or_path.replace("/", "-")
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
# add hash to avoid conflict with old users' lock files
lock_file_name = hash_name + model_name + ".lock"
# mode 0o666 is required for the filelock to be shared across users
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
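# Illustrative usage (the model name is a placeholder); the download paths further
# below use exactly this pattern:
#
#   with get_lock("org/model-name", cache_dir=None):
#       ...  # perform snapshot_download() while holding the lock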
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
def convert_bin_to_safetensor_file(
pt_filename: str,
sf_filename: str,
) -> None:
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
shared = _shared_pointers(loaded)
for shared_weights in shared:
for name in shared_weights[1:]:
loaded.pop(name)
# For tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}
dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
save_file(loaded, sf_filename, metadata={"format": "pt"})
# check file size
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
"""
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, model_config.download_dir):
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm,
)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_config_files = [
f
for f in config_files
if any(f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
if len(quant_config_files) > 1:
raise ValueError(
f"Found multiple config files for {model_config.quantization}: "
f"{quant_config_files}"
)
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
return quant_cls.from_config(config)
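# Rough usage sketch (hedged; in vllm this is invoked by the model loader):
#
#   quant_config = get_quant_config(model_config)   # e.g. an AWQConfig instance
#   # The returned QuantizationConfig is then passed to the quantized linear
#   # layers when the model is constructed.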
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
fall_back_to_pt: bool = True,
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
elif load_format == "tensorizer":
allow_patterns = ["*.tensors"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if not is_local and load_format != "tensorizer":
# Before we download, we look at what is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
revision=revision,
)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
]
if load_format == "tensorizer":
return hf_folder, hf_weights_files, use_safetensors
if len(hf_weights_files) == 0:
raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
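# Sketch of how the returned tuple is consumed (model and file names are
# placeholders; see hf_model_weights_iterator below for the real call site):
#
#   hf_folder, weight_files, use_safetensors = prepare_hf_model_weights(
#       "org/model-name", load_format="auto")
#   # hf_folder        -> local snapshot directory
#   # weight_files     -> e.g. ["model-00001-of-00002.safetensors", ...]
#   # use_safetensors  -> True when *.safetensors files were selected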
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: Union[Tuple, str] = "auto",
revision: Optional[str] = None,
fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
model_name_or_path,
cache_dir=cache_dir,
load_format=load_format,
fall_back_to_pt=fall_back_to_pt,
revision=revision,
)
if load_format == "npcache":
# Currently np_cache only supports *.bin checkpoints
assert use_safetensors is False
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, "weight_names.json")
# Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names = []
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)
with open(weight_names_file, "r") as f:
weight_names = json.load(f)
for name in weight_names:
param_path = os.path.join(np_folder, name)
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
elif load_format == "tensorizer":
from vllm.model_executor.tensorizer_loader import (
TensorDeserializer,
open_stream,
tensorizer_warning,
)
tensorizer_args = load_format.params
tensorizer_warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models."
)
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
for name, param in state.items():
yield name, param
del state
elif use_safetensors:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param
del state
torch.cuda.empty_cache()
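# Illustrative consumer (model name is a placeholder): stream (name, tensor)
# pairs into a module's parameters without materializing a full state dict,
# mirroring the load_weights() loops in the model files above:
#
#   params_dict = dict(model.named_parameters())
#   for name, loaded_weight in hf_model_weights_iterator("org/model-name"):
#       if "rotary_emb.inv_freq" in name:
#           continue
#       param = params_dict[name]
#       weight_loader = getattr(param, "weight_loader", default_weight_loader)
#       weight_loader(param, loaded_weight)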
def kv_cache_scales_loader(
filename: str,
tp_rank: int,
tp_size: int,
num_hidden_layers: int,
model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of examples/fp8/extract_scales.py
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct, context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.")
except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.")
except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}")
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading."
)
return []
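# Assumed on-disk shape (hedged example showing only the fields this loader reads;
# the authoritative definition is QuantParamSchema and vllm's
# examples/fp8/extract_scales.py):
#
#   {"model_type": "llama",
#    "kv_cache": {"scaling_factor": {"0": {"0": 0.92, "1": 1.07}}}}
#
# i.e. scaling_factor maps TP rank -> (layer index -> scale), which is exactly
# what schema.kv_cache.scaling_factor[tp_rank] reads above.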
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor
PySafeSlice object supports indexing, which is done before loading the
actual tensor and can reduce the amount of memory being read into the
memory. However, it does not support more advanced functionalities
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
tensor with these more complicated operators, we need to convert to
tensor first.
"""
if not isinstance(x, torch.Tensor):
x = x[:]
return x
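# Illustrative use with safetensors' lazy slices (file and tensor names are
# placeholders):
#
#   with safe_open("model.safetensors", framework="pt") as f:
#       lazy = f.get_slice("lm_head.weight")   # PySafeSlice, not yet loaded
#       t = convert_pyslice_to_tensor(lazy)    # materialized; .view()/.t() now work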
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
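# Minimal illustration (param and loaded_weight are placeholders): parameters that
# need fused or sharded loading attach their own `weight_loader`; everything else
# falls back to this plain copy once the shapes already match.
#
#   loader = getattr(param, "weight_loader", default_weight_loader)
#   loader(param, loaded_weight)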
def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,
high: float = 1e-3,
) -> None:
"""Initialize model weights with random values.
The model weights must be randomly initialized for accurate performance
measurements. Additionally, the model weights should not cause NaNs in the
forward pass. We empirically found that initializing the weights with
values between -1e-3 and 1e-3 works well for most models.
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
param.data.uniform_(low, high)
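# Typical use (sketch): when weights are not loaded from a checkpoint, e.g. for a
# profiling or benchmarking run, call
#
#   initialize_dummy_weights(model)
#
# so the forward pass exercises realistic compute without producing NaNs.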
...@@ -141,7 +141,7 @@ def encode_frame(frame): ...@@ -141,7 +141,7 @@ def encode_frame(frame):
def encode_video_base64(video_path, num_frames=16): def encode_video_base64(video_path, num_frames=16):
import cv2 import cv2 # pip install opencv-python-headless
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
if not cap.isOpened(): if not cap.isOpened():
......