Unverified commit 69e1d2fb, authored by Antoni Baum, committed by GitHub
Browse files

[Core] Refactor model loading code (#4097)

parent 05434764
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights.""" """Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -331,11 +330,7 @@ class XverseForCausalLM(nn.Module): ...@@ -331,11 +330,7 @@ class XverseForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [ stacked_params_mapping = [
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
...@@ -344,8 +339,7 @@ class XverseForCausalLM(nn.Module): ...@@ -344,8 +339,7 @@ class XverseForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in weights:
model_name_or_path, cache_dir, load_format, revision):
if ("rotary_emb.inv_freq" in name if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name or "rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name): or "rotary_emb.sin_cached" in name):
......
import os
from typing import Optional, Union from typing import Optional, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from vllm.config import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.transformers_utils.tokenizers import BaichuanTokenizer
...@@ -57,9 +59,26 @@ def get_tokenizer( ...@@ -57,9 +59,26 @@ def get_tokenizer(
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
trust_remote_code: bool = False, trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface.""" """Gets a tokenizer for the given model name via Huggingface/modelscope."""
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
# Only set the tokenizer here, model will be downloaded on the workers.
if not os.path.exists(tokenizer_name):
tokenizer_path = snapshot_download(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=tokenizer_revision,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = tokenizer_path
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
if kwargs.get("use_fast", False): if kwargs.get("use_fast", False):
raise ValueError( raise ValueError(
......
...@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple ...@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
...@@ -26,6 +26,7 @@ class CPUModelRunner: ...@@ -26,6 +26,7 @@ class CPUModelRunner:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
...@@ -36,6 +37,7 @@ class CPUModelRunner: ...@@ -36,6 +37,7 @@ class CPUModelRunner:
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py. # model_config can be None in tests/samplers/test_sampler.py.
...@@ -55,8 +57,10 @@ class CPUModelRunner: ...@@ -55,8 +57,10 @@ class CPUModelRunner:
self.model_config.dtype if model_config is not None else None) self.model_config.dtype if model_config is not None else None)
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(self.model_config, self.model = get_model(model_config=self.model_config,
self.device_config, load_config=self.load_config,
device_config=self.device_config,
vision_language_config=None,
lora_config=self.lora_config, lora_config=self.lora_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config) scheduler_config=self.scheduler_config)
......
...@@ -5,8 +5,8 @@ import torch ...@@ -5,8 +5,8 @@ import torch
import torch.distributed import torch.distributed
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig) ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
...@@ -117,6 +117,7 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -117,6 +117,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
...@@ -129,6 +130,7 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -129,6 +130,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
...@@ -141,6 +143,7 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -141,6 +143,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
parallel_config, parallel_config,
scheduler_config, scheduler_config,
device_config, device_config,
load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
......
...@@ -9,9 +9,8 @@ import torch.nn as nn ...@@ -9,9 +9,8 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend) get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
SchedulerConfig, TensorizerConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce, from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils) pynccl_utils)
...@@ -108,17 +107,17 @@ class ModelRunner: ...@@ -108,17 +107,17 @@ class ModelRunner:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None,
tensorizer_config: Optional[TensorizerConfig] = None,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lora_config = lora_config self.lora_config = lora_config
self.tensorizer_config = tensorizer_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py. # model_config can be None in tests/samplers/test_sampler.py.
...@@ -156,13 +155,13 @@ class ModelRunner: ...@@ -156,13 +155,13 @@ class ModelRunner:
def load_model(self) -> None: def load_model(self) -> None:
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model( self.model = get_model(
self.model_config, model_config=self.model_config,
self.device_config, device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
tensorizer_config=self.tensorizer_config,
) )
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
......
...@@ -6,7 +6,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -6,7 +6,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.neuron_model_loader import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available, from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
......
...@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Set, Tuple ...@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Set, Tuple
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig, TensorizerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
...@@ -38,12 +38,12 @@ class Worker(WorkerBase): ...@@ -38,12 +38,12 @@ class Worker(WorkerBase):
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None,
tensorizer_config: Optional[TensorizerConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
...@@ -55,7 +55,7 @@ class Worker(WorkerBase): ...@@ -55,7 +55,7 @@ class Worker(WorkerBase):
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.tensorizer_config = tensorizer_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
...@@ -70,11 +70,11 @@ class Worker(WorkerBase): ...@@ -70,11 +70,11 @@ class Worker(WorkerBase):
parallel_config, parallel_config,
scheduler_config, scheduler_config,
device_config, device_config,
load_config=load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
tensorizer_config=tensorizer_config,
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment