Unverified Commit 9760fd8f authored by 22quinn's avatar 22quinn Committed by GitHub
Browse files

[Core] Support inplace model weights loading (#18745)


Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent b9f61e13
...@@ -4,7 +4,6 @@ import gc ...@@ -4,7 +4,6 @@ import gc
import os import os
import pathlib import pathlib
import subprocess import subprocess
from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
...@@ -16,7 +15,6 @@ from vllm.engine.arg_utils import EngineArgs ...@@ -16,7 +15,6 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer, TensorSerializer,
is_vllm_tensorized, is_vllm_tensorized,
load_with_tensorizer,
open_stream, open_stream,
tensorize_vllm_model) tensorize_vllm_model)
# yapf: enable # yapf: enable
...@@ -61,21 +59,6 @@ def write_keyfile(keyfile_path: str): ...@@ -61,21 +59,6 @@ def write_keyfile(keyfile_path: str):
f.write(encryption_params.key) f.write(encryption_params.key)
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock()
mock_agent_instance = mock_agent.return_value
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
quant_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner): def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
......
...@@ -94,6 +94,9 @@ def model_runner(): ...@@ -94,6 +94,9 @@ def model_runner():
return runner return runner
model_runner_2 = model_runner
def _schedule_new_request(*req_ids: str) -> SchedulerOutput: def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
new_reqs = [] new_reqs = []
num_scheduled_tokens = {} num_scheduled_tokens = {}
...@@ -366,3 +369,18 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): ...@@ -366,3 +369,18 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert all(kv.is_contiguous() for kv in model_runner.kv_caches) assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
else: else:
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
# In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace
model_runner.load_model()
original_load_format = model_runner_2.load_config.load_format
model_runner_2.load_config.load_format = "dummy"
model_runner_2.load_model() # Initial model loading with dummy weights
assert str(model_runner.get_model().state_dict()) != str(
model_runner_2.get_model().state_dict())
model_runner_2.load_config.load_format = original_load_format
model_runner_2.load_model() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict())
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
class BaseModelLoader(ABC): class BaseModelLoader(ABC):
...@@ -18,7 +21,22 @@ class BaseModelLoader(ABC): ...@@ -18,7 +21,22 @@ class BaseModelLoader(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load_model(self, *, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
"""Load weights into a model. This standalone API allows
inplace weights loading for an already-initialized model"""
raise NotImplementedError
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
raise NotImplementedError device_config = vllm_config.device_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
return model.eval()
...@@ -14,7 +14,7 @@ from huggingface_hub import HfApi ...@@ -14,7 +14,7 @@ from huggingface_hub import HfApi
from torch import nn from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
# yapf: enable # yapf: enable
...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import (LinearBase, ...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import (LinearBase,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (ParamMapping, from vllm.model_executor.model_loader.utils import (ParamMapping,
initialize_model,
set_default_torch_dtype) set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
...@@ -408,8 +407,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -408,8 +407,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
), "vllm currently does not support BNB quantization for" ), "vllm currently does not support BNB quantization for"
f" {type(model).__name__}" f" {type(model).__name__}"
def _load_weights(self, model_config: ModelConfig, def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
model: nn.Module) -> None:
if not hasattr(model, "load_weights"): if not hasattr(model, "load_weights"):
raise AttributeError( raise AttributeError(
"The required method 'load_weights' is not defined in class" "The required method 'load_weights' is not defined in class"
...@@ -568,15 +566,3 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -568,15 +566,3 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config)
self._load_weights(model_config, model)
return model.eval()
...@@ -12,11 +12,9 @@ from torch import nn ...@@ -12,11 +12,9 @@ from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs from vllm import envs
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig from vllm.config import LoadConfig, LoadFormat, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
...@@ -264,32 +262,20 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -264,32 +262,20 @@ class DefaultModelLoader(BaseModelLoader):
fall_back_to_pt=True, fall_back_to_pt=True,
allow_patterns_overrides=None) allow_patterns_overrides=None)
def load_model(self, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> None:
device_config = vllm_config.device_config weights_to_load = {name for name, _ in model.named_parameters()}
target_device = torch.device(device_config.device) loaded_weights = model.load_weights(
with set_default_torch_dtype(model_config.dtype): self.get_all_weights(model_config, model))
with target_device: self.counter_after_loading_weights = time.perf_counter()
model = initialize_model(vllm_config=vllm_config, logger.info(
model_config=model_config) "Loading weights took %.2f seconds",
self.counter_after_loading_weights -
weights_to_load = {name for name, _ in model.named_parameters()} self.counter_before_loading_weights)
loaded_weights = model.load_weights( # We only enable strict check for non-quantized models
self.get_all_weights(model_config, model)) # that have loaded weights tracking currently.
self.counter_after_loading_weights = time.perf_counter() if model_config.quantization is None and loaded_weights is not None:
logger.info( weights_not_loaded = weights_to_load - loaded_weights
"Loading weights took %.2f seconds", if weights_not_loaded:
self.counter_after_loading_weights - raise ValueError("Following weights were not initialized from "
self.counter_before_loading_weights) f"checkpoint: {weights_not_loaded}")
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
process_weights_after_loading(model, model_config, target_device)
return model.eval()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
initialize_dummy_weights) initialize_dummy_weights)
...@@ -22,16 +19,8 @@ class DummyModelLoader(BaseModelLoader): ...@@ -22,16 +19,8 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download pass # Nothing to download
def load_model(self, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> None:
device_config = vllm_config.device_config # NOTE(woosuk): For accurate performance evaluation, we assign
target_device = torch.device(device_config.device) # random values to the weights.
with set_default_torch_dtype(model_config.dtype): initialize_dummy_weights(model)
with target_device:
model = initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
process_weights_after_loading(model, model_config, target_device)
return model.eval()
...@@ -92,6 +92,13 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -92,6 +92,13 @@ class GGUFModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model) self._prepare_weights(model_config.model)
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))
def load_model(self, vllm_config: VllmConfig, def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
...@@ -106,8 +113,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -106,8 +113,7 @@ class GGUFModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = initialize_model(vllm_config=vllm_config) model = initialize_model(vllm_config=vllm_config)
model.load_weights( self.load_weights(model, model_config)
self._get_weights_iterator(local_model_path, gguf_weights_map))
process_weights_after_loading(model, model_config, target_device) process_weights_after_loading(model, model_config, target_device)
return model return model
...@@ -9,10 +9,8 @@ import torch ...@@ -9,10 +9,8 @@ import torch
from torch import nn from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
runai_safetensors_weights_iterator) runai_safetensors_weights_iterator)
...@@ -100,21 +98,11 @@ class RunaiModelStreamerLoader(BaseModelLoader): ...@@ -100,21 +98,11 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Download model if necessary""" """Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> None:
"""Perform streaming of the model to destination""" """Load weights into a model."""
device_config = vllm_config.device_config model_weights = model_config.model
target_device = torch.device(device_config.device) if hasattr(model_config, "model_weights"):
with set_default_torch_dtype(model_config.dtype): model_weights = model_config.model_weights
with target_device: model.load_weights(
model = initialize_model(vllm_config=vllm_config) self._get_weights_iterator(model_weights, model_config.revision))
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
model.load_weights(
self._get_weights_iterator(model_weights,
model_config.revision))
process_weights_after_loading(model, model_config, target_device)
return model.eval()
...@@ -9,11 +9,9 @@ from typing import Any, Optional ...@@ -9,11 +9,9 @@ from typing import Any, Optional
import torch import torch
from torch import nn from torch import nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, runai_safetensors_weights_iterator) download_weights_from_hf, runai_safetensors_weights_iterator)
from vllm.transformers_utils.s3_utils import glob as s3_glob from vllm.transformers_utils.s3_utils import glob as s3_glob
...@@ -100,11 +98,8 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -100,11 +98,8 @@ class ShardedStateLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> None:
device_config = vllm_config.device_config
target_device = torch.device(device_config.device)
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
model_weights = model_config.model model_weights = model_config.model
...@@ -112,53 +107,47 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -112,53 +107,47 @@ class ShardedStateLoader(BaseModelLoader):
model_weights = model_config.model_weights model_weights = model_config.model_weights
local_model_path = model_weights local_model_path = model_weights
with set_default_torch_dtype(model_config.dtype): rank = get_tensor_model_parallel_rank()
with target_device: pattern = os.path.join(
model = initialize_model(vllm_config=vllm_config) local_model_path,
process_weights_after_loading(model, model_config, self.pattern.format(rank=rank, part="*"),
target_device) )
rank = get_tensor_model_parallel_rank()
pattern = os.path.join( filepaths = []
local_model_path, if is_s3(local_model_path):
self.pattern.format(rank=rank, part="*"), file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}"
) filepaths = s3_glob(path=local_model_path,
allow_pattern=[file_pattern])
filepaths = [] else:
if is_s3(local_model_path): filepaths = glob.glob(pattern)
file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" if not filepaths:
filepaths = s3_glob(path=local_model_path, # TODO: support un-sharded checkpoints too
allow_pattern=[file_pattern]) raise ValueError(
else: f"Could not find checkpoint files '{pattern}', only "
filepaths = glob.glob(pattern) f"pre-sharded checkpoints are currently supported!")
if not filepaths: state_dict = self._filter_subtensors(model.state_dict())
# TODO: support un-sharded checkpoints too for key, tensor in self.iterate_over_files(filepaths):
raise ValueError( # If loading with LoRA enabled, additional padding may
f"Could not find checkpoint files '{pattern}', only " # be added to certain parameters. We only load into a
f"pre-sharded checkpoints are currently supported!") # narrowed view of the parameter data.
state_dict = self._filter_subtensors(model.state_dict()) param_data = state_dict[key].data
for key, tensor in self.iterate_over_files(filepaths): param_shape = state_dict[key].shape
# If loading with LoRA enabled, additional padding may for dim, size in enumerate(tensor.shape):
# be added to certain parameters. We only load into a if size < param_shape[dim]:
# narrowed view of the parameter data. param_data = param_data.narrow(dim, 0, size)
param_data = state_dict[key].data if tensor.shape != param_shape:
param_shape = state_dict[key].shape logger.warning(
for dim, size in enumerate(tensor.shape): "loading tensor of shape %s into "
if size < param_shape[dim]: "parameter '%s' of shape %s",
param_data = param_data.narrow(dim, 0, size) tensor.shape,
if tensor.shape != param_shape: key,
logger.warning( param_shape,
"loading tensor of shape %s into " )
"parameter '%s' of shape %s", param_data.copy_(tensor)
tensor.shape, state_dict.pop(key)
key, if state_dict:
param_shape, raise ValueError(
) f"Missing keys {tuple(state_dict)} in loaded state!")
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval()
def iterate_over_files( def iterate_over_files(
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
......
...@@ -21,7 +21,8 @@ from torch.utils._python_dispatch import TorchDispatchMode ...@@ -21,7 +21,8 @@ from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
set_current_vllm_config)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -208,12 +209,6 @@ class TensorizerConfig: ...@@ -208,12 +209,6 @@ class TensorizerConfig:
**tensorizer_args.stream_params) **tensorizer_args.stream_params)
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
**extra_kwargs) -> nn.Module:
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
return tensorizer.deserialize()
@dataclass @dataclass
class TensorizerArgs: class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
...@@ -366,100 +361,72 @@ class TensorizerArgs: ...@@ -366,100 +361,72 @@ class TensorizerArgs:
return tensorizer_args return tensorizer_args
class TensorizerAgent: def _check_tensors_on_meta_device(model: nn.Module) -> None:
""" for tensor in model.state_dict().values():
A class for performing tensorizer deserializations specifically for if tensor.device.type == 'meta':
vLLM models using plaid_mode. Uses TensorizerArgs to configure the raise ValueError(
behavior of the TensorDeserializer when loading tensors from a serialized "The serialized model contains tensors on the meta device,"
model. For deserializations of HuggingFace models, TensorDeserializer is " indicating that some tensors were not loaded properly."
instead used as an iterator directly in the func hf_model_weights_iterator " Please check that the parameters of the model being"
in vllm/model_executor/model_loader/weight_utils.py " specified match that of the serialized model, such as"
""" " its quantization.")
def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
self.tensorizer_config = tensorizer_config def _resize_lora_embeddings(model: nn.Module):
self.tensorizer_args = ( """Modify LoRA embedding layers to use bigger tensors
self.tensorizer_config._construct_tensorizer_args()) to allow for adapter added tokens."""
self.vllm_config = vllm_config for child in model.modules():
self.model = self._init_model() if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0]
< child.num_embeddings_per_partition):
def _init_model(self): new_weight = torch.empty(child.num_embeddings_per_partition,
assert self.tensorizer_config.hf_config is not None child.embedding_dim,
model_args = self.tensorizer_config.hf_config dtype=child.weight.dtype,
model_args.torch_dtype = self.tensorizer_config.dtype device=child.weight.device)
assert self.tensorizer_config.model_class is not None new_weight[:child.weight.shape[0]].copy_(child.weight.data)
# TODO: Do we need to consider old-style model class? new_weight[child.weight.shape[0]:].fill_(0)
with meta_tensor_mode(), set_current_vllm_config(self.vllm_config, child.weight.data = new_weight
check_compile=True):
return self.tensorizer_config.model_class(
vllm_config=self.vllm_config) def init_tensorizer_model(tensorizer_config: TensorizerConfig,
vllm_config: VllmConfig) -> nn.Module:
def _resize_lora_embeddings(self): assert tensorizer_config.hf_config is not None
"""Modify LoRA embedding layers to use bigger tensors model_args = tensorizer_config.hf_config
to allow for adapter added tokens.""" model_args.torch_dtype = tensorizer_config.dtype
for child in self.model.modules(): assert tensorizer_config.model_class is not None
if (isinstance(child, VocabParallelEmbedding) # TODO: Do we need to consider old-style model class?
and child.weight.shape[0] with meta_tensor_mode(), set_current_vllm_config(vllm_config,
< child.num_embeddings_per_partition): check_compile=True):
new_weight = torch.empty(child.num_embeddings_per_partition, return tensorizer_config.model_class(vllm_config=vllm_config)
child.embedding_dim,
dtype=child.weight.dtype,
device=child.weight.device) def deserialize_tensorizer_model(model: nn.Module,
new_weight[:child.weight.shape[0]].copy_(child.weight.data) tensorizer_config: TensorizerConfig) -> None:
new_weight[child.weight.shape[0]:].fill_(0) tensorizer_args = tensorizer_config._construct_tensorizer_args()
child.weight.data = new_weight before_mem = get_mem_usage()
start = time.perf_counter()
def _check_tensors_on_meta_device(self): with _read_stream(
for tensor in self.model.state_dict().values(): tensorizer_config.tensorizer_uri,
if tensor.device.type == 'meta': **tensorizer_args.stream_params) as stream, TensorDeserializer(
raise ValueError(
"The serialized model contains tensors on the meta device,"
" indicating that some tensors were not loaded properly."
" Please check that the parameters of the model being"
" specified match that of the serialized model, such as"
" its quantization.")
def deserialize(self):
"""
Deserialize the model using the TensorDeserializer. This method is
specifically for vLLM models using tensorizer's plaid_mode.
The deserializer makes use of tensorizer_args.stream_params
to configure the behavior of the stream when loading tensors from a
serialized model. The deserializer_params are used to configure the
behavior of the TensorDeserializer when loading tensors themselves.
Documentation on these params can be found in TensorizerArgs
Returns:
nn.Module: The deserialized model.
"""
before_mem = get_mem_usage()
start = time.perf_counter()
with _read_stream(
self.tensorizer_config.tensorizer_uri,
**self.tensorizer_args.stream_params
) as stream, TensorDeserializer(
stream, stream,
dtype=self.tensorizer_config.dtype, dtype=tensorizer_config.dtype,
device=f'cuda:{torch.cuda.current_device()}', device=f'cuda:{torch.cuda.current_device()}',
**self.tensorizer_args.deserializer_params) as deserializer: **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(self.model) deserializer.load_into_module(model)
end = time.perf_counter() end = time.perf_counter()
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration) per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage() after_mem = get_mem_usage()
deserializer.close() deserializer.close()
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
end - start, per_second) end - start, per_second)
logger.info("Memory usage before: %s", before_mem) logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem) logger.info("Memory usage after: %s", after_mem)
self._check_tensors_on_meta_device() _check_tensors_on_meta_device(model)
self._resize_lora_embeddings() _resize_lora_embeddings(model)
del self.model.vllm_tensorized_marker del model.vllm_tensorized_marker
return self.model.eval()
def tensorizer_weights_iterator( def tensorizer_weights_iterator(
......
...@@ -11,8 +11,8 @@ from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig ...@@ -11,8 +11,8 @@ from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model,
serialize_vllm_model, tensorizer_weights_iterator) is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture, from vllm.model_executor.model_loader.utils import (get_model_architecture,
initialize_model, initialize_model,
set_default_torch_dtype) set_default_torch_dtype)
...@@ -61,38 +61,34 @@ class TensorizerLoader(BaseModelLoader): ...@@ -61,38 +61,34 @@ class TensorizerLoader(BaseModelLoader):
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
return model.eval() return model.eval()
def _load_model_serialized(
self,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the
examples/others/tensorize_vllm_model.py example script
for serializing vLLM models."""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config,
vllm_config=vllm_config)
return model.eval()
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self.tensorizer_config.verify_with_model_config(model_config) self.tensorizer_config.verify_with_model_config(model_config)
with self.tensorizer_config.open_stream(): with self.tensorizer_config.open_stream():
pass pass
def _patch_tensorizer_config(
self, model_config: ModelConfig) -> TensorizerConfig:
model_class = get_model_architecture(model_config)[0]
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
return tensorizer_config
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
"""Load serialized model weights with tensorizer.
Expects a vLLM-tensorized model. See the
examples/others/tensorize_vllm_model.py example script
for serializing vLLM models."""
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
deserialize_tensorizer_model(model, tensorizer_config)
else:
model.load_weights(self._get_weights_iterator())
def load_model(self, vllm_config: VllmConfig, def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> nn.Module:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
...@@ -106,7 +102,11 @@ class TensorizerLoader(BaseModelLoader): ...@@ -106,7 +102,11 @@ class TensorizerLoader(BaseModelLoader):
get_tensor_model_parallel_rank()) get_tensor_model_parallel_rank())
if is_vllm_tensorized(self.tensorizer_config): if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(vllm_config=vllm_config) tensorizer_config = self._patch_tensorizer_config(model_config)
model = init_tensorizer_model(tensorizer_config=tensorizer_config,
vllm_config=vllm_config)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config) return self._load_model_serialized_cpu(vllm_config=vllm_config)
@staticmethod @staticmethod
......
...@@ -28,7 +28,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context, ...@@ -28,7 +28,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context) set_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
...@@ -1564,7 +1564,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1564,7 +1564,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117 with DeviceMemoryProfiler() as m: # noqa: SIM117
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
self.model = get_model(vllm_config=self.vllm_config) model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"):
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.model_config)
else:
logger.info(
"Model was already initialized. Loading weights inplace..."
)
model_loader.load_weights(self.model,
model_config=self.model_config)
if self.lora_config: if self.lora_config:
self.model = self.load_lora_model(self.model, self.model = self.load_lora_model(self.model,
self.model_config, self.model_config,
......
...@@ -21,7 +21,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config ...@@ -21,7 +21,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange) PlaceholderRange)
...@@ -171,7 +171,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -171,7 +171,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
# Lazy initialization # Lazy initialization
# self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output) # req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
...@@ -419,7 +419,6 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -419,7 +419,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
assert self.model is not None
return self.model return self.model
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
...@@ -936,7 +935,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -936,7 +935,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"vllm.model_executor.layers.vocab_parallel_embedding." "vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank", "get_tensor_model_parallel_rank",
return_value=xm_tp_rank): return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config) # model = get_model(vllm_config=self.vllm_config)
model_loader = get_model_loader(self.load_config)
if not hasattr(self, "model"):
logger.info("Loading model from scratch...")
model = model_loader.load_model(vllm_config=self.vllm_config,
model_config=self.model_config)
else:
logger.info(
"Model was already initialized. Loading weights inplace..."
)
model_loader.load_weights(self.model,
model_config=self.model_config)
if self.lora_config is not None: if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config, model = self.load_lora_model(model, self.model_config,
self.scheduler_config, self.scheduler_config,
...@@ -947,7 +957,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -947,7 +957,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# loading. # loading.
xm.mark_step() xm.mark_step()
xm.wait_device_ops() xm.wait_device_ops()
self.model = model if not hasattr(self, "model"):
self.model = model
self.sampler = TPUSampler() self.sampler = TPUSampler()
@torch.no_grad() @torch.no_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment