Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
...@@ -920,13 +920,10 @@ def get_rope( ...@@ -920,13 +920,10 @@ def get_rope(
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype) is_neox_style, dtype)
else: else:
scaling_type = rope_scaling[ scaling_type = rope_scaling["rope_type"]
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope"}:
scaling_factor = rope_scaling.get("factor", 1.0)
if scaling_type == "llama3": if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"] low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[ original_max_position = rope_scaling[
...@@ -937,16 +934,39 @@ def get_rope( ...@@ -937,16 +934,39 @@ def get_rope(
scaling_factor, low_freq_factor, scaling_factor, low_freq_factor,
high_freq_factor, high_freq_factor,
original_max_position) original_max_position)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear": elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base, max_position, base,
is_neox_style, is_neox_style,
scaling_factor, dtype) scaling_factor, dtype)
elif scaling_type == "dynamic": elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding( rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype) scaling_factor, dtype)
elif scaling_type == "yarn": elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[ original_max_position = rope_scaling[
"original_max_position_embeddings"] "original_max_position_embeddings"]
extra_kwargs = { extra_kwargs = {
...@@ -961,6 +981,7 @@ def get_rope( ...@@ -961,6 +981,7 @@ def get_rope(
scaling_factor, dtype, scaling_factor, dtype,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "deepseek_yarn": elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[ original_max_position = rope_scaling[
"original_max_position_embeddings"] "original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor # assert max_position == original_max_position * scaling_factor
...@@ -973,9 +994,7 @@ def get_rope( ...@@ -973,9 +994,7 @@ def get_rope(
rotary_emb = DeepseekScalingRotaryEmbedding( rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base, head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs) is_neox_style, scaling_factor, dtype, **extra_kwargs)
# The correct one should be "longrope" but keep "su" here elif scaling_type == "longrope":
# for backward compatible
elif scaling_type == "su" or scaling_type == "longrope":
short_factor = rope_scaling["short_factor"] short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"] long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[ original_max_position = rope_scaling[
...@@ -989,16 +1008,6 @@ def get_rope( ...@@ -989,16 +1008,6 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position, head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor, base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "mrope":
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
......
...@@ -4,7 +4,7 @@ import warnings ...@@ -4,7 +4,7 @@ import warnings
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
from math import inf from math import inf
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, Iterator, List, Optional, Tuple, Union
import msgspec import msgspec
import torch import torch
...@@ -117,12 +117,15 @@ class SamplerOutput( ...@@ -117,12 +117,15 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time. # block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None model_execute_time: Optional[float] = None
def __getitem__(self, idx: int): def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx] return self.outputs[idx]
def __setitem__(self, idx: int, value): def __setitem__(self, idx: int, value):
self.outputs[idx] = value self.outputs[idx] = value
def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
return iter(self.outputs)
def __len__(self): def __len__(self):
return len(self.outputs) return len(self.outputs)
...@@ -508,7 +511,7 @@ def _random_sample( ...@@ -508,7 +511,7 @@ def _random_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
# Find the maximum best_of value of the prompt phase requests. # Find the maximum n value of the prompt phase requests.
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -523,9 +526,9 @@ def _random_sample( ...@@ -523,9 +526,9 @@ def _random_sample(
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
parent_ids = [0] * sampling_params.best_of parent_ids = [0] * sampling_params.n
next_token_ids = random_samples[ next_token_ids = random_samples[
sample_idx, :sampling_params.best_of].tolist() sample_idx, :sampling_params.n].tolist()
else: else:
# Generation phase. # Generation phase.
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
...@@ -570,7 +573,7 @@ def _beam_search_sample( ...@@ -570,7 +573,7 @@ def _beam_search_sample(
is_prompt = seq_group.is_prompt is_prompt = seq_group.is_prompt
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of beam_width = sampling_params.n
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
...@@ -797,12 +800,11 @@ def _sample_with_torch( ...@@ -797,12 +800,11 @@ def _sample_with_torch(
greedy_samples) greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1 max_n_in_batch = 1
for seq_group in seq_groups: for seq_group in seq_groups:
if seq_group.is_prompt: if seq_group.is_prompt:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch, max_n_in_batch = max(max_n_in_batch, sampling_params.n)
sampling_params.best_of)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups) seq_groups)
...@@ -812,13 +814,13 @@ def _sample_with_torch( ...@@ -812,13 +814,13 @@ def _sample_with_torch(
probs[long_sample_indices], probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices], sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices], sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch, max_n_in_batch,
seq_groups_arg, seq_groups_arg,
) )
else: else:
multinomial_samples[sampling_type] = _multinomial( multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices], probs[long_sample_indices],
max_best_of_in_batch, max_n_in_batch,
seq_groups=seq_groups_arg) seq_groups=seq_groups_arg)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
...@@ -912,7 +914,7 @@ def get_logprobs( ...@@ -912,7 +914,7 @@ def get_logprobs(
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: SampleResultType, sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs. """Return sample logprobs and prompt logprobs.
The logic consists of 3 parts. The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and - Select indices to compute logprob from, ranks of token ids, and
...@@ -947,8 +949,6 @@ def get_logprobs( ...@@ -947,8 +949,6 @@ def get_logprobs(
# largest num logprobs in this API. If every logprobs is None, it will be # largest num logprobs in this API. If every logprobs is None, it will be
# set to -1. # set to -1.
largest_num_logprobs = -1 largest_num_logprobs = -1
# If beam search is enabled.
use_beam_search = False
# Select indices to compute logprob from, ranks of token ids, and the top # Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs. # k token ids from logprobs.
...@@ -981,8 +981,6 @@ def get_logprobs( ...@@ -981,8 +981,6 @@ def get_logprobs(
largest_num_logprobs = max(largest_num_logprobs, largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs) sampling_params.logprobs)
use_beam_search = use_beam_search or sampling_params.use_beam_search
assert len(next_token_ids) == len(query_indices) assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0: if len(query_indices) == 0:
...@@ -995,7 +993,7 @@ def get_logprobs( ...@@ -995,7 +993,7 @@ def get_logprobs(
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation. # skip the whole logprob calculation.
if largest_num_logprobs >= 0 or use_beam_search: if largest_num_logprobs >= 0:
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, next_token_ids_gpu = torch.tensor(next_token_ids,
device=logprobs.device) device=logprobs.device)
...@@ -1121,13 +1119,12 @@ def _get_sampled_logprob_if_needed( ...@@ -1121,13 +1119,12 @@ def _get_sampled_logprob_if_needed(
"""Compute the sample logprob if needed.""" """Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs num_logprobs = seq_group.sampling_params.logprobs
use_beam_search = seq_group.sampling_params.use_beam_search
sampled_logprobs: SampleLogprobs = [] sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample: if seq_group.do_sample:
assert len(next_token_ids) > 0 assert len(next_token_ids) > 0
if num_logprobs is None and not use_beam_search: if num_logprobs is None:
for next_token_id in next_token_ids: for next_token_id in next_token_ids:
# Use a dummy logprob # Use a dummy logprob
sampled_logprobs.append({next_token_id: Logprob(inf)}) sampled_logprobs.append({next_token_id: Logprob(inf)})
......
...@@ -457,7 +457,7 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -457,7 +457,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super().__init__(num_embeddings, embedding_dim, params_dtype, super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config, org_num_embeddings, padding_size, quant_config,
prefix) prefix)
self.quant_config = quant_config
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
...@@ -469,6 +469,15 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -469,6 +469,15 @@ class ParallelLMHead(VocabParallelEmbedding):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self
def forward(self, input_): def forward(self, input_):
del input_ del input_
raise RuntimeError("LMHead's weights should be used in the sampler.") raise RuntimeError("LMHead's weights should be used in the sampler.")
...@@ -41,9 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -41,9 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator) safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (has_inner_state, from vllm.model_executor.models import (has_inner_state, supports_lora,
supports_lora, supports_multimodal)
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -444,6 +443,18 @@ class DummyModelLoader(BaseModelLoader): ...@@ -444,6 +443,18 @@ class DummyModelLoader(BaseModelLoader):
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
return model.eval() return model.eval()
...@@ -728,15 +739,26 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -728,15 +739,26 @@ class ShardedStateLoader(BaseModelLoader):
class BitsAndBytesModelLoader(BaseModelLoader): class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization.""" """Model loader to load model weights with BitAndBytes quantization."""
# TODO: these module names are for Llama only, possible_config_file_names = ["adapter_config.json"]
# change so that it works with other models as well
default_target_modules = [ default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", ".gate_proj.",
"o_proj" ".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
'.fc1.',
'.fc2.',
'.dense.',
'.query_key_value.',
'.qkv_proj.',
'.dense_h_to_4h.',
'.dense_4h_to_h.',
'.out_proj.',
] ]
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
...@@ -746,7 +768,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -746,7 +768,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if (not load_config.model_loader_extra_config if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path" or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config): not in load_config.model_loader_extra_config):
self.target_modules = self.default_target_modules self.target_modules = []
return return
qlora_adapter = load_config.model_loader_extra_config[ qlora_adapter = load_config.model_loader_extra_config[
...@@ -893,10 +915,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -893,10 +915,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter( for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors): hf_weights_files, use_safetensors):
if not weight_name.endswith(".weight"): if not weight_name.endswith((".weight", ".bias")):
continue continue
qweight_name = weight_name.replace(".weight", ".qweight") qweight_name = weight_name.replace(".weight", ".qweight")
if qweight_name in quant_state_dict: if qweight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True}) set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor yield qweight_name, weight_tensor
...@@ -912,7 +935,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -912,7 +935,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
use_safetensors) use_safetensors)
temp_state_dict = {} temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator: for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"): if weight_name.endswith((".weight", ".bias")):
continue continue
# bitsandbytes library requires # bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU # weight.quant_state.bitsandbytes__* in CPU
...@@ -935,9 +958,10 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -935,9 +958,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# pre quantized weights would have a quant_state # pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter( for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors): hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"): if not weight_name.endswith((".weight", ".bias")):
continue continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
in temp_state_dict) or \ in temp_state_dict) or \
(f"{weight_name}.quant_state.bitsandbytes__fp4" \ (f"{weight_name}.quant_state.bitsandbytes__fp4" \
...@@ -957,15 +981,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -957,15 +981,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter( for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors): hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules): if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
weight_name = weight_name.replace(".weight", ".qweight") weight_name = weight_name.replace(".weight", ".qweight")
# weight partitions of different modules occur at if any(module in weight_name
# different dimensions for module in self.column_parallel_weights_modules):
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
total_size = weight_tensor.size(-1) total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1) end_index = total_size // tp_size * (tp_rank + 1)
...@@ -1014,6 +1037,20 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -1014,6 +1037,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"Model {type(model).__name__} does not support BitsAndBytes " f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.") "quantization yet.")
if len(self.target_modules) == 0:
if hasattr(model, 'default_bitsandbytes_target_modules'):
self.target_modules = model.default_bitsandbytes_target_modules
else:
self.target_modules = self.default_target_modules
if hasattr(model, 'column_parallel_weights_modules'):
self.column_parallel_weights_modules = \
model.column_parallel_weights_modules
else:
self.column_parallel_weights_modules = []
self.model_type = type(model).__name__
logger.info("Loading weights with BitsAndBytes quantization. " logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...") " May take a while ...")
......
"""Utilities for selecting and loading neuron models.""" """Utilities for selecting and loading neuron models."""
import copy
import importlib import importlib
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -13,6 +14,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -13,6 +14,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)
TORCH_DTYPE_TO_NEURON_AMP = { TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32", "auto": "f32",
...@@ -37,15 +40,18 @@ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { ...@@ -37,15 +40,18 @@ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
class NeuronCasualLM(nn.Module): class NeuronCasualLM(nn.Module):
def __init__( def __init__(self,
self, config: PretrainedConfig,
config: PretrainedConfig, on_device_sampling_disabled: bool = False) -> None:
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size, self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True) logits_as_input=True)
self.sampler = Sampler()
self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()
# Lazy initialized # Lazy initialized
self.model: nn.Module self.model: nn.Module
...@@ -71,8 +77,29 @@ class NeuronCasualLM(nn.Module): ...@@ -71,8 +77,29 @@ class NeuronCasualLM(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens if self.on_device_sampling_disabled:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
# On-device sampling outputs the token ids directly.
sampled_token_ids = logits.flatten()
next_tokens = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
samples = []
for seq_id in seq_group.seq_ids:
token_id = sampled_token_ids[sample_idx].item()
samples.append(
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
next_tokens.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))
return SamplerOutput(outputs=next_tokens)
def load_weights(self, model_name_or_path: str, **kwargs): def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config) arch = _get_model_architecture(self.config)
...@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig, ...@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
quant=neuron_quantization_config_builder(model_config.quantization) quant=neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None, if model_config.quantization else None,
continuous_batching=continuous_batching_config, continuous_batching=continuous_batching_config,
weight_tiling=bool(model_config.quantization)) weight_tiling=bool(model_config.quantization),
on_device_generation=_get_neuron_on_device_generation_config(
model_config))
return default_neuron_args return default_neuron_args
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
if not _is_neuron_on_device_sampling_disabled(model_config):
return copy.deepcopy(model_config.neuron_sampling_params)
return None
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
return not getattr(model_config, "neuron_sampling_params", None)
def _get_neuron_config_after_override(default_neuron_config, def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config): overridden_neuron_config):
from transformers_neuronx.config import NeuronConfig from transformers_neuronx.config import NeuronConfig
...@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig, ...@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig) -> nn.Module:
# Create a model instance. # Create a model instance.
model = NeuronCasualLM(model_config.hf_config) model = NeuronCasualLM(
model_config.hf_config,
_is_neuron_on_device_sampling_disabled(model_config))
default_neuron_config_args = _get_default_neuron_config( default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config) model_config, parallel_config, scheduler_config)
......
...@@ -12,6 +12,7 @@ from torch import nn ...@@ -12,6 +12,7 @@ from torch import nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig from vllm.config import DeviceConfig, ModelConfig
from vllm.executor.openvino_executor import is_openvino_cpu
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor, from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states) _prune_hidden_states)
...@@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type, ...@@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
shape = parameter.get_partial_shape() shape = parameter.get_partial_shape()
# use real block size if available, just a placeholder # use real block size if available, just a placeholder
# to provide the expected rank # to provide the expected rank
x_size = 1
num_blocks = ov.Dimension() num_blocks = ov.Dimension()
block_size = ov.Dimension() block_size = ov.Dimension()
head_size = ov.Dimension() head_size = ov.Dimension()
# TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
# pass more parameters to this function to set more static dimensions
if input_name.startswith("key_cache."): if input_name.startswith("key_cache."):
cpu_shape = [num_blocks, shape[1], block_size, head_size] cpu_shape = [num_blocks, shape[1], block_size, head_size]
gpu_shape = [ gpu_shape = [num_blocks, shape[1], shape[2], block_size]
num_blocks,
shape[1],
shape[2].get_length() //
x_size if shape[2].is_static else ov.Dimension(),
block_size,
x_size,
]
elif input_name.startswith("value_cache."): elif input_name.startswith("value_cache."):
cpu_shape = [num_blocks, shape[1], block_size, head_size] cpu_shape = [num_blocks, shape[1], block_size, head_size]
gpu_shape = [num_blocks, shape[1], shape[2], block_size] gpu_shape = [num_blocks, shape[1], block_size, shape[2]]
else: else:
continue continue
parameter.set_partial_shape( parameter.set_partial_shape(
...@@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module): ...@@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module):
def __init__( def __init__(
self, self,
ov_core: ov.Core,
model_config: ModelConfig, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
kv_cache_dtype: ov.Type, kv_cache_dtype: ov.Type,
...@@ -141,12 +133,12 @@ class OpenVINOCasualLM(nn.Module): ...@@ -141,12 +133,12 @@ class OpenVINOCasualLM(nn.Module):
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
ov_device = envs.VLLM_OPENVINO_DEVICE
paged_attention_transformation(pt_model.model) paged_attention_transformation(pt_model.model)
_modify_cache_parameters(pt_model.model, kv_cache_dtype, _modify_cache_parameters(pt_model.model, kv_cache_dtype,
device_config.device.type == "cpu") is_openvino_cpu())
core = ov.Core() ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
ov_compiled = core.compile_model(pt_model.model, "CPU")
self.ov_request = ov_compiled.create_infer_request() self.ov_request = ov_compiled.create_infer_request()
def forward( def forward(
...@@ -199,6 +191,7 @@ def get_model( ...@@ -199,6 +191,7 @@ def get_model(
**kwargs, **kwargs,
) -> torch.nn.Module: ) -> torch.nn.Module:
lora_config = kwargs.get("lora_config", None) lora_config = kwargs.get("lora_config", None)
ov_core = kwargs.get("ov_core")
if lora_config: if lora_config:
raise ValueError( raise ValueError(
"OpenVINO modeling does not support LoRA, " "OpenVINO modeling does not support LoRA, "
...@@ -206,4 +199,5 @@ def get_model( ...@@ -206,4 +199,5 @@ def get_model(
"be added in the future. If this is important to you, " "be added in the future. If this is important to you, "
"please open an issue on github.") "please open an issue on github.")
return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype) return OpenVINOCasualLM(ov_core, model_config, device_config,
kv_cache_dtype)
...@@ -22,15 +22,15 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,15 +22,15 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visual = getattr(model_config.hf_config, "visual", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if architectures == ['QWenLMHeadModel'] and visual != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM']: if architectures == ['BloomForCausalLM'] or architectures ==['Qwen2ForCausalLM'] or architectures == ['LlamaForCausalLM']:
os.environ['LM_TN'] = '1' os.environ['LM_TN'] = '1'
else: else:
os.environ['LM_TN'] = '0' os.environ['LM_TN'] = '0'
...@@ -57,7 +57,9 @@ def get_model_architecture( ...@@ -57,7 +57,9 @@ def get_model_architecture(
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
]
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported and model_config.quantization not in mixtral_supported
......
...@@ -6,7 +6,8 @@ import json ...@@ -6,7 +6,8 @@ import json
import os import os
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, Union)
import filelock import filelock
import gguf import gguf
...@@ -498,8 +499,8 @@ def kv_cache_scales_loader( ...@@ -498,8 +499,8 @@ def kv_cache_scales_loader(
logger.error("File or directory '%s' not found.", filename) logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename) logger.error("Error decoding JSON in file '%s'.", filename)
except Exception as e: except Exception:
logger.error("An error occurred while reading '%s': %s", filename, e) logger.exception("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit # This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded # Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales # which ultimately defaults to 1.0 scales
...@@ -559,6 +560,38 @@ def row_parallel_weight_loader(param: torch.Tensor, ...@@ -559,6 +560,38 @@ def row_parallel_weight_loader(param: torch.Tensor,
return default_weight_loader(param, loaded_weight) return default_weight_loader(param, loaded_weight)
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
"""Create a weight loader that shards the weights along the given axis"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
shard_size = param.data.shape[shard_axis]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
return loader
def composed_weight_loader(
loader: LoaderFunction, fn: Callable[[torch.Tensor],
torch.Tensor]) -> LoaderFunction:
"""Create a weight loader that post-processes the weights after loading"""
def composed_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
loader(param, loaded_weight)
param.data.copy_(fn(param))
return
return composed_loader
def initialize_dummy_weights( def initialize_dummy_weights(
model: torch.nn.Module, model: torch.nn.Module,
low: float = -1e-3, low: float = -1e-3,
......
import functools from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
import importlib SupportsPP, has_inner_state, supports_lora,
from typing import Dict, List, Optional, Tuple, Type supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding,
import torch.nn as nn VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
from vllm.logger import init_logger from .registry import ModelRegistry
from vllm.utils import is_hip
logger = init_logger(__name__)
_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"TeleChat12BForCausalLM": ("telechat_12B", "TeleChat12BForCausalLM"), # telechat12b
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM")
}
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}
_MULTIMODAL_MODELS = {
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": ("llava",
"LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next",
"LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"LlavaOnevisionForConditionalGeneration":
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS,
}
# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
# "Qwen2ForCausalLM":
# _ROCM_SWA_REASON,
# "MistralForCausalLM":
# _ROCM_SWA_REASON,
# "MixtralForCausalLM":
# _ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
class ModelRegistry:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
return ModelRegistry._get_model(model_arch)
@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
@staticmethod
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
# TODO: find a way to avoid initializing CUDA prematurely to
# use `supports_multimodal` to determine if a model is multimodal
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
# from vllm.model_executor.models.interfaces import supports_multimodal
return model_arch in _MULTIMODAL_MODELS
__all__ = [ __all__ = [
"ModelRegistry", "ModelRegistry",
] "VllmModelForEmbedding",
"is_embedding_model",
"VllmModelForTextGeneration",
"is_text_generation_model",
"HasInnerState",
"has_inner_state",
"SupportsLoRA",
"supports_lora",
"SupportsMultiModal",
"supports_multimodal",
"SupportsPP",
"supports_pp",
]
\ No newline at end of file
"""Inference-only Snowflake Arctic model.""" """Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter) DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
...@@ -32,6 +31,10 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -32,6 +31,10 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -364,6 +367,7 @@ class ArcticModel(nn.Module): ...@@ -364,6 +367,7 @@ class ArcticModel(nn.Module):
config: ArcticConfig, config: ArcticConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -372,15 +376,16 @@ class ArcticModel(nn.Module): ...@@ -372,15 +376,16 @@ class ArcticModel(nn.Module):
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=self.vocab_size) org_num_embeddings=self.vocab_size)
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
ArcticDecoderLayer(config, config.num_hidden_layers,
layer_idx, lambda prefix: ArcticDecoderLayer(config, int(
cache_config, prefix.split(".")[-1]), cache_config, quant_config),
quant_config=quant_config) prefix=f"{prefix}.layers")
for layer_idx in range(config.num_hidden_layers)
])
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -388,17 +393,25 @@ class ArcticModel(nn.Module): ...@@ -388,17 +393,25 @@ class ArcticModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.layers)): if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i], hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata) attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class ArcticForCausalLM(nn.Module): class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(self,
config: ArcticConfig, config: ArcticConfig,
...@@ -422,6 +435,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -422,6 +435,8 @@ class ArcticForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -430,9 +445,9 @@ class ArcticForCausalLM(nn.Module): ...@@ -430,9 +445,9 @@ class ArcticForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -503,6 +518,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -503,6 +518,8 @@ class ArcticForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -512,6 +529,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -512,6 +529,8 @@ class ArcticForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -522,6 +541,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -522,6 +541,8 @@ class ArcticForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -532,6 +553,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -532,6 +553,8 @@ class ArcticForCausalLM(nn.Module):
else: else:
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" """Inference-only BaiChuan model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -29,7 +29,7 @@ import re ...@@ -29,7 +29,7 @@ import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -47,7 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -47,7 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
...@@ -269,7 +270,8 @@ class BaiChuanModel(nn.Module): ...@@ -269,7 +270,8 @@ class BaiChuanModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -279,12 +281,16 @@ class BaiChuanModel(nn.Module): ...@@ -279,12 +281,16 @@ class BaiChuanModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
BaiChuanDecoderLayer(config, position_embedding, cache_config, config.num_hidden_layers,
quant_config) lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
for _ in range(config.num_hidden_layers) cache_config, quant_config),
]) prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -292,23 +298,34 @@ class BaiChuanModel(nn.Module): ...@@ -292,23 +298,34 @@ class BaiChuanModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual,
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"W_pack": ["W_pack"], "W_pack": ["W_pack"],
"gate_up_proj": [ "gate_up_proj": [
...@@ -349,6 +366,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -349,6 +366,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
...@@ -359,7 +378,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -359,7 +378,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -367,9 +386,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -367,9 +386,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -418,6 +437,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -418,6 +437,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -426,6 +447,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -426,6 +447,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
...@@ -503,13 +526,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -503,13 +526,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B.""" """Baichuan 13B and Baichuan2 7B/13B."""
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
...@@ -527,7 +549,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -527,7 +549,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
......
...@@ -10,7 +10,7 @@ from transformers.models.blip.modeling_blip import BlipAttention ...@@ -10,7 +10,7 @@ from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -63,7 +63,7 @@ def dummy_seq_data_for_blip( ...@@ -63,7 +63,7 @@ def dummy_seq_data_for_blip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -89,14 +89,14 @@ def dummy_image_for_blip( ...@@ -89,14 +89,14 @@ def dummy_image_for_blip(
def input_processor_for_blip( def input_processor_for_blip(
model_config: ModelConfig, model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
...@@ -107,16 +107,16 @@ def input_processor_for_blip( ...@@ -107,16 +107,16 @@ def input_processor_for_blip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=image_token_id, placeholder_token_id=image_token_id,
repeat_count=image_feature_size, repeat_count=image_feature_size,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
......
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -8,19 +9,19 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, ...@@ -8,19 +9,19 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (group_weights_with_prefix, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
...@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2( ...@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, ...@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
hf_config = ctx.get_hf_config(Blip2Config) hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config) image_feature_size = get_blip2_image_feature_size(hf_config)
...@@ -460,22 +461,22 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -460,22 +461,22 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
# The original model places image tokens at the front # The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += llm_inputs["prompt_token_ids"] new_token_ids += inputs["prompt_token_ids"]
new_prompt = llm_inputs.get("prompt") new_prompt = inputs.get("prompt")
if new_prompt is not None: if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: Blip2Config, config: Blip2Config,
...@@ -508,6 +509,16 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -508,6 +509,16 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -600,7 +611,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -600,7 +611,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2. """Run forward pass for BLIP-2.
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
...@@ -631,26 +642,32 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -631,26 +642,32 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
See also: See also:
:class:`Blip2ImageInputs` :class:`Blip2ImageInputs`
""" """
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID) BLIP2_IMAGE_TOKEN_ID)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(
positions, input_ids,
kv_caches, positions,
attn_metadata, kv_caches,
inputs_embeds=inputs_embeds) attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -670,35 +687,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -670,35 +687,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_model.load_weights(weights_group["vision_model"])
# load query tokens
for name, loaded_weight in weights_group["query_tokens"]:
assert name == ""
param = self.query_tokens
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load qformer
qformer_params_dict = dict(self.qformer.named_parameters())
for name, loaded_weight in weights_group["qformer"]:
param = qformer_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load mlp projector
mlp_params_dict = dict(self.language_projection.named_parameters())
for name, loaded_weight in weights_group["language_projection"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.""" """Inference-only BLOOM model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -27,15 +27,14 @@ import re ...@@ -27,15 +27,14 @@ import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -45,6 +44,10 @@ from vllm.sequence import IntermediateTensors ...@@ -45,6 +44,10 @@ from vllm.sequence import IntermediateTensors
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
...@@ -230,6 +233,7 @@ class BloomModel(nn.Module): ...@@ -230,6 +233,7 @@ class BloomModel(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -243,13 +247,16 @@ class BloomModel(nn.Module): ...@@ -243,13 +247,16 @@ class BloomModel(nn.Module):
self.embed_dim, eps=config.layer_norm_epsilon) self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
BloomBlock(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: BloomBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -257,22 +264,29 @@ class BloomModel(nn.Module): ...@@ -257,22 +264,29 @@ class BloomModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.word_embeddings(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.word_embeddings_layernorm(hidden_states) if get_pp_group().is_first_rank:
for i in range(len(self.h)): hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class BloomForCausalLM(nn.Module): class BloomForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -292,6 +306,8 @@ class BloomForCausalLM(nn.Module): ...@@ -292,6 +306,8 @@ class BloomForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
...@@ -309,9 +325,9 @@ class BloomForCausalLM(nn.Module): ...@@ -309,9 +325,9 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -338,6 +354,8 @@ class BloomForCausalLM(nn.Module): ...@@ -338,6 +354,8 @@ class BloomForCausalLM(nn.Module):
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
......
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict) Tuple, TypedDict, Union)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -10,8 +10,9 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig ...@@ -10,8 +10,9 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -33,7 +34,9 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -33,7 +34,9 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
# These configs are not part of the model config but the preprocessor # These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now. # and processor files, so we hardcode them in the model file for now.
...@@ -67,7 +70,7 @@ def dummy_seq_data_for_chameleon( ...@@ -67,7 +70,7 @@ def dummy_seq_data_for_chameleon(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -104,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, ...@@ -104,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
return seq_data, mm_data return seq_data, mm_data
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_chameleon(ctx: InputContext,
inputs: DecoderOnlyInputs):
""" """
Processing input prompt to insert required tokens for image placeholder. Processing input prompt to insert required tokens for image placeholder.
...@@ -112,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -112,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
""" # noqa """ # noqa
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
...@@ -135,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -135,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids += [CHAMELEON_SEP_TOKEN_ID] new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
class ChameleonLayerNorm(nn.LayerNorm): class ChameleonLayerNorm(nn.LayerNorm):
...@@ -822,6 +826,7 @@ class ChameleonModel(nn.Module): ...@@ -822,6 +826,7 @@ class ChameleonModel(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -835,14 +840,20 @@ class ChameleonModel(nn.Module): ...@@ -835,14 +840,20 @@ class ChameleonModel(nn.Module):
config.vocabulary_map) config.vocabulary_map)
decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
else ChameleonSwinDecoderLayer else ChameleonSwinDecoderLayer
self.layers = nn.ModuleList([
decoder_layer(config=config, self.start_layer, self.end_layer, self.layers = make_layers(
cache_config=cache_config, config.num_hidden_layers,
quant_config=quant_config) lambda prefix: decoder_layer(config=config,
for _ in range(config.num_hidden_layers) cache_config=cache_config,
]) quant_config=quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.vqmodel = ChameleonVQVAE(config.vq_config) self.vqmodel = ChameleonVQVAE(config.vq_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -865,22 +876,33 @@ class ChameleonModel(nn.Module): ...@@ -865,22 +876,33 @@ class ChameleonModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.layers)): residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -889,7 +911,8 @@ class ChameleonModel(nn.Module): ...@@ -889,7 +911,8 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__( def __init__(
self, self,
...@@ -914,6 +937,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -914,6 +937,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
...@@ -956,22 +981,26 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -956,22 +981,26 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if intermediate_tensors is not None:
assert self.model.vqmodel is not None input_ids = None
image_tokens = self.model.get_image_tokens(image_input["data"].to( else:
self.config.torch_dtype)) image_input = self._parse_and_validate_image_input(**kwargs)
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id if image_input is not None:
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) assert self.model.vqmodel is not None
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens = self.model.get_image_tokens(
image_tokens) image_input["data"].to(self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device,
input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -1039,6 +1068,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1039,6 +1068,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -1060,11 +1091,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1060,11 +1091,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
continue continue
else: else:
name = remapped_kv_scale_name name = remapped_kv_scale_name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if use_default_weight_loading and name in params_dict: if use_default_weight_loading and name in params_dict:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
# coding=utf-8 # coding=utf-8
# Adapted from # Adapted from
# https://github.com/THUDM/ChatGLM2-6B # https://github.com/THUDM/GLM-4
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
import torch import torch
from PIL import Image
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
import os import os
import re import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -26,14 +31,198 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput ...@@ -26,14 +31,198 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsMultiModal
logger = init_logger(__name__)
def calculate_image_placeholder(vision_config):
return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
def mm_input_mapper_for_glmv(
ctx: InputContext,
data: MultiModalData[object],
) -> Dict:
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
if tokenizer is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": data
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True).data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
pixel_values = raw_batch_data['images']
return MultiModalInputs({'pixel_values': pixel_values})
def merge_glm_vision_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
boi_token_id: int,
eoi_token_id: int,
) -> torch.Tensor:
boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
mask = torch.zeros_like(input_ids, dtype=torch.bool)
for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
assert boi_pos < eoi_pos
mask[boi_pos:eoi_pos + 1] = True
inputs_embeds[mask] = vision_embeddings.view(-1,
vision_embeddings.shape[-1])
return inputs_embeds
class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
def get_max_glmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return 1
elif isinstance(vision_config, dict):
return calculate_image_placeholder(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_data_for_glmv(
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
seq_data = SequenceData(token_ids)
return seq_data, None
elif isinstance(vision_config, dict):
image_size = vision_config["image_size"]
image_placeholder_length = calculate_image_placeholder(vision_config)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
[0] * image_placeholder_length +
[hf_config.eoi_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0] * (seq_len - image_placeholder_length - 2))
seq_data = SequenceData(token_ids)
mm_data = {
"image": Image.new("RGB", (image_size, image_size), color=0)
}
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def find_all_positions(input_ids: List[int], target: int) -> List[int]:
return [index for index, value in enumerate(input_ids) if value == target]
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return inputs
elif isinstance(vision_config, dict):
image_placeholder_length = calculate_image_placeholder(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
input_ids = inputs.get("prompt_token_ids")
position_ids = inputs.get("position_ids")
tokenizer = cached_get_tokenizer(
ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code)
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": inputs['multi_modal_data']["image"],
"content": inputs['prompt']
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True).data
except Exception:
logger.error("Failed to process content (%s)", inputs['prompt'])
raise
input_ids = raw_batch_data['input_ids'][0].tolist()
if position_ids is None:
position_ids = list(range(len(input_ids)))
boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id
boi_positions = find_all_positions(input_ids, boi_token_id)
eoi_positions = find_all_positions(input_ids, eoi_token_id)
assert len(boi_positions) == len(eoi_positions)
new_input_ids = []
new_position_ids = []
final_processed_position = 0
final_processed_position = 0
for boi_position, eoi_position in zip(boi_positions, eoi_positions):
assert boi_position < eoi_position
new_input_ids.extend(input_ids[final_processed_position:boi_position +
1])
new_position_ids.extend(
list(range(final_processed_position, boi_position + 1)))
new_input_ids.extend([input_ids[boi_position + 1]] *
image_placeholder_length)
new_position_ids.extend([boi_position + 1] * image_placeholder_length)
final_processed_position = eoi_position
new_input_ids.extend(input_ids[final_processed_position:])
new_position_ids.extend(
list(range(final_processed_position, len(input_ids))))
assert len(new_input_ids) == len(new_position_ids)
inputs["prompt_token_ids"] = new_input_ids
inputs["position_ids"] = new_position_ids
return inputs
>>>>>>> v0.6.3.post1
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
...@@ -306,8 +495,11 @@ class ChatGLMModel(nn.Module): ...@@ -306,8 +495,11 @@ class ChatGLMModel(nn.Module):
): ):
super().__init__() super().__init__()
self.config = config
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size) config.hidden_size,
quant_config=quant_config)
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
...@@ -318,26 +510,72 @@ class ChatGLMModel(nn.Module): ...@@ -318,26 +510,72 @@ class ChatGLMModel(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config, quant_config)
else:
self.vision = None
def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None and self.vision is not None:
if isinstance(pixel_values, torch.Tensor):
if pixel_values.ndim > 2:
pixel_values = torch.concat(list(pixel_values))
elif isinstance(pixel_values, list):
return torch.concat(pixel_values)
else:
raise TypeError("""pixel_values must be a torch.Tensor
or a list of torch.Tensor
""")
return GLMImagePixelInputs(pixel_values=pixel_values)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids) inputs_embeds = self.embedding(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is not None:
pixel_values = image_input["pixel_values"].to(
dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values)
boi_token_id = self.config.boi_token_id
eoi_token_id = self.config.eoi_token_id
inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=image_embeds,
boi_token_id=boi_token_id,
eoi_token_id=eoi_token_id)
# Run encoder. # Run encoder.
hidden_states = self.encoder( hidden_states = self.encoder(
hidden_states=inputs_embeds, hidden_states=inputs_embeds,
position_ids=position_ids, position_ids=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
return hidden_states return hidden_states
class ChatGLMForCausalLM(nn.Module, SupportsLoRA): @MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
...@@ -355,6 +593,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -355,6 +593,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
def __init__( def __init__(
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
...@@ -363,6 +602,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -363,6 +602,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", self.max_position_embeddings = getattr(config, "max_sequence_length",
...@@ -384,16 +624,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -384,16 +624,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(self,
self, input_ids: torch.Tensor,
input_ids: torch.Tensor, positions: torch.Tensor,
positions: torch.Tensor, kv_caches: List[torch.Tensor],
kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata,
attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs) -> torch.Tensor:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, **kwargs)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -414,8 +653,24 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -414,8 +653,24 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
"transformer.vision.linear_proj.gate_proj.weight": None,
"transformer.vision.linear_proj.dense_h_to_4h.weight": None,
}
}
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
if name in merged_weight_dict:
assert merged_weight_dict[name] is None
merged_weight_dict[name] = loaded_weight
is_weight_to_be_merge = True
if is_weight_to_be_merge:
continue
if "rotary_pos_emb.inv_freq" in name: if "rotary_pos_emb.inv_freq" in name:
continue continue
if "word_embeddings" in name: if "word_embeddings" in name:
...@@ -427,6 +682,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -427,6 +682,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
param = params_dict[combined_name]
combined_weight = torch.cat(list(merged_weight_dict.values()),
dim=0)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)
if self.use_llama_nn and self.quant_method is None: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
...@@ -464,3 +728,4 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -464,3 +728,4 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
...@@ -11,7 +11,7 @@ from transformers.models.clip.modeling_clip import CLIPSdpaAttention ...@@ -11,7 +11,7 @@ from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -62,7 +62,7 @@ def dummy_seq_data_for_clip( ...@@ -62,7 +62,7 @@ def dummy_seq_data_for_clip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -106,14 +106,14 @@ def dummy_video_for_clip( ...@@ -106,14 +106,14 @@ def dummy_video_for_clip(
def input_processor_for_clip( def input_processor_for_clip(
model_config: ModelConfig, model_config: ModelConfig,
hf_config: CLIPVisionConfig, hf_config: CLIPVisionConfig,
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None, image_feature_size_override: Optional[Union[int, List[int]]] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
...@@ -130,16 +130,16 @@ def input_processor_for_clip( ...@@ -130,16 +130,16 @@ def input_processor_for_clip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=image_token_id, placeholder_token_id=image_token_id,
repeat_count=image_feature_size, repeat_count=image_feature_size,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -29,25 +29,27 @@ from transformers import CohereConfig ...@@ -29,25 +29,27 @@ from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader) default_weight_loader, maybe_remap_kv_scale_name,
row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@torch.compile @torch.compile
...@@ -82,7 +84,7 @@ class CohereMLP(nn.Module): ...@@ -82,7 +84,7 @@ class CohereMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -256,6 +258,7 @@ class CohereModel(nn.Module): ...@@ -256,6 +258,7 @@ class CohereModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -265,12 +268,16 @@ class CohereModel(nn.Module): ...@@ -265,12 +268,16 @@ class CohereModel(nn.Module):
self.org_vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
CohereDecoderLayer(config, cache_config, quant_config=quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: CohereDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size), self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -278,23 +285,34 @@ class CohereModel(nn.Module): ...@@ -278,23 +285,34 @@ class CohereModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class CohereForCausalLM(nn.Module, SupportsLoRA): class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -337,6 +355,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA): ...@@ -337,6 +355,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
quant_config, quant_config,
lora_config=lora_config) lora_config=lora_config)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -346,9 +366,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA): ...@@ -346,9 +366,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -393,6 +413,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA): ...@@ -393,6 +413,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -405,6 +427,13 @@ class CohereForCausalLM(nn.Module, SupportsLoRA): ...@@ -405,6 +427,13 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
# coding=utf-8 # coding=utf-8
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DbrxRouter(nn.Module): class DbrxRouter(nn.Module):
"""A Router implementation for DBRX that returns logits for each expert """A Router implementation for DBRX that returns logits for each expert
...@@ -296,22 +300,27 @@ class DbrxModel(nn.Module): ...@@ -296,22 +300,27 @@ class DbrxModel(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList([ self.start_layer, self.end_layer, self.blocks = make_layers(
DbrxBlock(config, cache_config, quant_config) config.n_layers,
for _ in range(config.n_layers) lambda prefix: DbrxBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules(): for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias, if hasattr(module, "bias") and isinstance(module.bias,
nn.Parameter): nn.Parameter):
# Remove the bias term in Linear and LayerNorm. # Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None) module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model))
def forward( def forward(
self, self,
...@@ -319,21 +328,28 @@ class DbrxModel(nn.Module): ...@@ -319,21 +328,28 @@ class DbrxModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.blocks)): if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
block = self.blocks[i] block = self.blocks[i]
hidden_states = block( hidden_states = block(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
return hidden_states return hidden_states
class DbrxForCausalLM(nn.Module): class DbrxForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -359,6 +375,8 @@ class DbrxForCausalLM(nn.Module): ...@@ -359,6 +375,8 @@ class DbrxForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -367,9 +385,9 @@ class DbrxForCausalLM(nn.Module): ...@@ -367,9 +385,9 @@ class DbrxForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -401,11 +419,20 @@ class DbrxForCausalLM(nn.Module): ...@@ -401,11 +419,20 @@ class DbrxForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name) weight_loader(param, loaded_weight, weight_name)
break break
else: else:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -29,11 +29,12 @@ import torch ...@@ -29,11 +29,12 @@ import torch
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from .utils import is_pp_missing_parameter
class DeciLMForCausalLM(LlamaForCausalLM): class DeciLMForCausalLM(LlamaForCausalLM):
""" """
...@@ -91,6 +92,8 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -91,6 +92,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -99,6 +102,8 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -99,6 +102,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment