Unverified Commit 69e1d2fb authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

[Core] Refactor model loading code (#4097)

parent 05434764
......@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -331,11 +330,7 @@ class XverseForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
......@@ -344,8 +339,7 @@ class XverseForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
......
import os
from typing import Optional, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.config import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
......@@ -57,9 +59,26 @@ def get_tokenizer(
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
"""Gets a tokenizer for the given model name via Huggingface/modelscope."""
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
# Only set the tokenizer here, model will be downloaded on the workers.
if not os.path.exists(tokenizer_name):
tokenizer_path = snapshot_download(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=tokenizer_revision,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = tokenizer_path
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError(
......
......@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple
import torch
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
......@@ -26,6 +26,7 @@ class CPUModelRunner:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
......@@ -36,6 +37,7 @@ class CPUModelRunner:
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
......@@ -55,8 +57,10 @@ class CPUModelRunner:
self.model_config.dtype if model_config is not None else None)
def load_model(self) -> None:
self.model = get_model(self.model_config,
self.device_config,
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=None,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
......
......@@ -5,8 +5,8 @@ import torch
import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
......@@ -117,6 +117,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
......@@ -129,6 +130,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
......@@ -141,6 +143,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
parallel_config,
scheduler_config,
device_config,
load_config=self.load_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
......
......@@ -9,9 +9,8 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, TensorizerConfig,
VisionLanguageConfig)
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils)
......@@ -108,17 +107,17 @@ class ModelRunner:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
tensorizer_config: Optional[TensorizerConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.tensorizer_config = tensorizer_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
......@@ -156,13 +155,13 @@ class ModelRunner:
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
self.model_config,
self.device_config,
model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
tensorizer_config=self.tensorizer_config,
)
self.model_memory_usage = m.consumed_memory
......
......@@ -6,7 +6,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.neuron_model_loader import get_neuron_model
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
......
......@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Set, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, TensorizerConfig,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
......@@ -38,12 +38,12 @@ class Worker(WorkerBase):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
tensorizer_config: Optional[TensorizerConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
......@@ -55,7 +55,7 @@ class Worker(WorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.tensorizer_config = tensorizer_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
......@@ -70,11 +70,11 @@ class Worker(WorkerBase):
parallel_config,
scheduler_config,
device_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
tensorizer_config=tensorizer_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment