Unverified Commit 69e1d2fb authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

[Core] Refactor model loading code (#4097)

parent 05434764
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TensorizerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
......@@ -23,20 +23,20 @@ class ExecutorBase(ABC):
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
self.tensorizer_config = tensorizer_config
self._init_executor()
......
......@@ -35,12 +35,12 @@ class GPUExecutor(ExecutorBase):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
is_driver_worker=True,
)
self.driver_worker.init_device()
......
......@@ -147,6 +147,7 @@ class RayGPUExecutor(ExecutorBase):
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
load_config = copy.deepcopy(self.load_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
......@@ -165,12 +166,12 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
tensorizer_config=self.tensorizer_config,
))
# Initialize the driver worker with the Worker class.
......@@ -187,7 +188,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
load_config=self.load_config,
is_driver_worker=True,
)
......
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
import torch
from torch import nn
from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.tensorizer_loader import (
ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
load_with_tensorizer)
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def _get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_architecture_class_name(model_config: ModelConfig) -> str:
return _get_model_architecture(model_config)[1]
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
vision_language_config = kwargs.get("vision_language_config", None)
tensorizer_config = kwargs.get("tensorizer_config", None)
model_class = _get_model_architecture(model_config)[0]
# Get the (maybe quantized) linear method.
linear_method = None
if model_config.quantization is not None:
quant_config = get_quant_config(model_config)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
extra_kwargs = {}
if hasattr(model_class, "supported_lora_modules"):
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif model_class in _VISION_MODEL_CLASSES:
extra_kwargs["vision_language_config"] = vision_language_config
with torch.device(device_config.device):
if (model_config.load_format == "tensorizer"
and is_vllm_serialized_tensorizer(tensorizer_config)):
extra_kwargs["linear_method"] = linear_method
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
return model.eval()
model = model_class(config=model_config.hf_config,
linear_method=linear_method,
**extra_kwargs)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
if model_config.load_format == "tensorizer":
# Provide a dynamic load format for `model.load_weights`
# to retain tensorizer args from CLI.
model_config.load_format = ParameterizedLoadFormat(
model_config.load_format)
model_config.load_format.params = (
tensorizer_config._construct_tensorizer_args())
model.load_weights(
model_config.model,
model_config.download_dir,
model_config.load_format,
model_config.revision,
)
return model.eval()
from typing import Optional
from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)
def get_model(
*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config)
__all__ = [
"get_model", "get_model_loader", "BaseModelLoader",
"get_architecture_class_name", "get_model_architecture"
]
# ruff: noqa: SIM117
import copy
import glob
import os
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
Type)
import torch
from torch import nn
from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
logger = init_logger(__name__)
def _get_linear_method(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]:
"""Get the (maybe quantized) linear method."""
linear_method = None
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
return linear_method
def _get_model_initialization_kwargs(
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs = {}
if hasattr(model_class, "supported_lora_modules"):
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif model_class in _VISION_MODEL_CLASSES:
extra_kwargs["vision_language_config"] = vision_language_config
return extra_kwargs
def _initialize_model(
model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config)
return model_class(config=model_config.hf_config,
linear_method=linear_method,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
"""Load a model with the given configurations."""
...
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _maybe_download_from_modelscope(
self, model: str, revision: Optional[str]) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
if not os.path.exists(model):
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
revision=revision)
else:
model_path = model
return model_path
return None
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = self._maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS:
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.PT:
allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE:
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if not is_local:
hf_folder = download_weights_from_hf(model_name_or_path,
self.load_config.download_dir,
allow_patterns)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, model_name_or_path: str, revision: Optional[str],
fall_back_to_pt: bool
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision, fall_back_to_pt)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
return np_cache_weights_iterator(model_name_or_path,
self.load_config.download_dir,
hf_folder, hf_weights_files)
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
return pt_weights_iterator(hf_weights_files)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
fall_back_to_pt=getattr(
model,
"fall_back_to_pt_during_load",
True)), )
return model.eval()
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
return model.eval()
class TensorizerLoader(BaseModelLoader):
"""Model loader using CoreWeave's tensorizer library."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
self.tensorizer_config = load_config.model_loader_extra_config
else:
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config)
def _verify_config(self, model_config: ModelConfig,
parallel_config: ParallelConfig):
self.tensorizer_config.verify_with_model_config(model_config)
self.tensorizer_config.verify_with_parallel_config(parallel_config)
def _get_weights_iterator(
self) -> Generator[Tuple[str, torch.Tensor], None, None]:
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
return tensorizer_weights_iterator(tensorizer_args)
def _load_model_unserialized(
self, model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
) -> nn.Module:
"""Load an unserialized model with tensorizer.
Unserialized here means "not serialized with tensorizer". This
should still be faster than default HuggingFace loading, but will
be slower than loading a tensorizer-serialized model.
"""
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
model.load_weights(self._get_weights_iterator())
return model.eval()
def _load_model_serialized(
self, model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
) -> nn.Module:
"""Load a serialized model with tensorizer.
See the examples/tensorize_vllm_model.py example "
script for serializing vLLM models."""
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config,
self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
extra_kwargs["linear_method"] = linear_method
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
return model.eval()
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)
if is_vllm_serialized_tensorizer(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
vision_language_config)
return self._load_model_unserialized(model_config, device_config,
lora_config,
vision_language_config)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.DUMMY:
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)
return DefaultModelLoader(load_config)
......@@ -4,20 +4,20 @@ import io
import os
import time
import typing
import warnings
from dataclasses import dataclass
from typing import Optional, Union
from typing import Generator, Optional, Tuple, Type, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import TensorizerConfig
from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
tensorizer_load_fail = False
tensorizer_load_fail = None
try:
from tensorizer import (DecryptionParams, EncryptionParams,
......@@ -25,51 +25,78 @@ try:
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
except ImportError:
tensorizer_load_fail = True
except ImportError as e:
tensorizer_load_fail = e
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
'no_init_or_tensor'
'no_init_or_tensor', 'TensorizerConfig'
]
logger = init_logger(__name__)
@dataclass
class TensorizerConfig:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
model_class: Optional[Type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None
def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = {
"tensorizer_uri": self.tensorizer_uri,
"vllm_tensorized": self.vllm_tensorized,
"verify_hash": self.verify_hash,
"num_readers": self.num_readers,
"encryption_keyfile": self.encryption_keyfile,
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if (parallel_config.tensor_parallel_size > 1
and self.tensorizer_uri is not None):
raise ValueError(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`.")
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None
and self.tensorizer_uri is not None):
logger.warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.")
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
**extra_kwargs) -> nn.Module:
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
return tensorizer.deserialize()
def tensorizer_warning(message: str):
return warnings.warn(message, category=PerformanceWarning, stacklevel=2)
def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
if tensorizer_config is None:
return False
return tensorizer_config.vllm_tensorized
class ParameterizedLoadFormat(str):
__slots__ = "params"
class PerformanceWarning(UserWarning):
def __str__(self):
return (f"{super().__str__()}"
" (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
" environment variable to hide this)")
if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
not in ("", "0", "n", "no", "off", "disable")):
warnings.simplefilter("ignore", category=PerformanceWarning)
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
......@@ -219,11 +246,17 @@ class TensorizerAgent:
behavior of the TensorDeserializer when loading tensors from a serialized
model. For deserializations of HuggingFace models, TensorDeserializer is
instead used as an iterator directly in the func hf_model_weights_iterator
in vllm/model_executor/weight_utils.py
in vllm/model_executor/model_loader/weight_utils.py
"""
def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs):
if tensorizer_load_fail is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`."
) from tensorizer_load_fail
self.tensorizer_config = tensorizer_config
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
......@@ -234,11 +267,6 @@ class TensorizerAgent:
self.linear_method = linear_method
self.model = self._init_model()
if tensorizer_load_fail:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`.")
def _init_model(self):
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
......@@ -313,3 +341,23 @@ class TensorizerAgent:
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
return self.model.eval()
def tensorizer_weights_iterator(
tensorizer_args: "TensorizerArgs"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
logger.warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models.")
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
for name, param in state.items():
yield name, param
del state
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
......@@ -4,8 +4,9 @@ import glob
import hashlib
import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Any, Generator, Iterable, List, Optional, Tuple
import filelock
import huggingface_hub.constants
......@@ -15,7 +16,7 @@ from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.config import ModelConfig
from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
......@@ -27,8 +28,7 @@ logger = init_logger(__name__)
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
'TEMP') or os.environ.get('TMP') or "/tmp/"
temp_dir = tempfile.gettempdir()
def enable_hf_transfer():
......@@ -46,7 +46,7 @@ def enable_hf_transfer():
enable_hf_transfer()
class Disabledtqdm(tqdm):
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
......@@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
......@@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, model_config.download_dir):
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm)
cache_dir=load_config.download_dir,
tqdm_class=DisabledTqdm)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
......@@ -153,36 +154,24 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
return quant_cls.from_config(config)
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
fall_back_to_pt: bool = True,
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) \
and load_format != "tensorizer"
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
elif load_format == "tensorizer":
allow_patterns = ["*.tensors"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
def download_weights_from_hf(model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None) -> str:
"""Download model weights from Hugging Face Hub.
if fall_back_to_pt:
allow_patterns += ["*.pt"]
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
allow_patterns (List[str]): The allowed patterns for the
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
if not is_local and load_format != "tensorizer":
Returns:
str: The path to the downloaded model weights.
"""
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
......@@ -201,20 +190,18 @@ def prepare_hf_model_weights(
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
tqdm_class=DisabledTqdm,
revision=revision)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
return hf_folder
def filter_files_not_needed_for_inference(
hf_weights_files: List[str]) -> List[str]:
"""
Exclude files that are not needed for inference.
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
"""
blacklist = [
"training_args.bin",
"optimizer.bin",
......@@ -226,35 +213,17 @@ def prepare_hf_model_weights(
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
return hf_weights_files
if load_format == "tensorizer":
return hf_folder, hf_weights_files, use_safetensors
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
def np_cache_weights_iterator(
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model np files.
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: Union[Tuple, str] = "auto",
revision: Optional[str] = None,
fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
model_name_or_path,
cache_dir=cache_dir,
load_format=load_format,
fall_back_to_pt=fall_back_to_pt,
revision=revision)
if load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
Will dump the model weights to numpy files if they are not already dumped.
"""
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
......@@ -283,33 +252,23 @@ def hf_model_weights_iterator(
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
elif load_format == "tensorizer":
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
open_stream,
tensorizer_warning)
tensorizer_args = load_format.params
tensorizer_warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models.")
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
for name, param in state.items():
yield name, param
del state
elif use_safetensors:
def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
......
......@@ -19,7 +19,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -340,19 +339,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
......
......@@ -17,7 +17,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -35,9 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -298,14 +297,9 @@ class BloomForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
if not name.startswith("transformer."):
......
......@@ -2,7 +2,7 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -22,9 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
......@@ -370,14 +369,9 @@ class ChatGLMForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
......
......@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
import torch.utils.checkpoint
......@@ -41,10 +41,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -335,13 +334,7 @@ class CohereForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -352,8 +345,7 @@ class CohereForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
......
# coding=utf-8
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
......@@ -18,10 +18,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
......@@ -391,20 +390,13 @@ class DbrxForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
......
......@@ -23,16 +23,15 @@
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Optional
from typing import Iterable, Optional, Tuple
import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
class DeciLMForCausalLM(LlamaForCausalLM):
......@@ -65,11 +64,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
linear_method=linear_method,
lora_config=lora_config)
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -79,8 +74,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -44,9 +44,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -316,6 +315,8 @@ class DeepseekDecoderLayer(nn.Module):
class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
......@@ -395,11 +396,7 @@ class DeepseekForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -410,12 +407,7 @@ class DeepseekForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
......
......@@ -19,7 +19,7 @@
"""PyTorch Falcon model."""
import math
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
......@@ -399,11 +398,7 @@ class FalconForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
......@@ -413,8 +408,7 @@ class FalconForCausalLM(nn.Module):
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if name == "lm_head.weight":
# Falcon uses tied embeddings.
continue
......
......@@ -15,7 +15,7 @@
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -36,9 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
......@@ -346,11 +345,7 @@ class GemmaForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -361,8 +356,7 @@ class GemmaForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
......
......@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -34,9 +34,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -239,14 +238,9 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment