Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy
from typing import Dict, Optional, Union
import torch
from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, UnquantizedEmbeddingMethod)
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
weight_bits = get_dynamic_override(config, prefix, "bits",
config.weight_bits)
if isinstance(weight_bits, int):
config.weight_bits = weight_bits
group_size = get_dynamic_override(config, prefix, "group_size",
config.group_size)
if isinstance(group_size, int):
config.group_size = group_size
desc_act = get_dynamic_override(config, prefix, "desc_act",
config.desc_act)
if isinstance(desc_act, bool):
config.desc_act = desc_act
config.pack_factor = 32 // config.weight_bits # packed into int32
if config.get_name() == "gptq_marlin":
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool):
config.is_sym = is_sym
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={config.weight_bits}, sym={config.is_sym}")
config.quant_type = config.TYPE_MAP[(config.weight_bits,
config.is_sym)]
elif config.get_name() == "gptq":
if config.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {config.weight_bits} bits.")
def get_dynamic_override(
config: QuantizationConfig,
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool,
None] = None) -> Union[Dict, int, bool, None]:
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
if re.match(pattern.removeprefix("-:"), layer_name):
return False
# Positive match: matched modules have quant properties overrides
# base quant config
elif re.match(pattern.removeprefix("+:"), layer_name):
if key is None:
return pattern_dict
else:
return pattern_dict.get(key, default_value)
return default_value
def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
prefix: str,
linear_method_cls: type,
):
cloned_config = deepcopy(config)
parallel_lm_head_quantized = isinstance(
layer, ParallelLMHead) and cloned_config.lm_head_quantized
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
# False = skip module, None = no override, else = Positive match
if get_dynamic_override( # noqa: E712
cloned_config, # noqa: E712
layer_name=prefix) == False: # noqa: E712
if parallel_lm_head_quantized:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
if prefix:
# Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix)
return linear_method_cls(cloned_config)
return None
......@@ -6,6 +6,7 @@ import numpy
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
......@@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return True, None
def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
output_size_per_partition = getattr(layer, "output_size_per_partition",
None) or layer.output_size
input_size_per_partition = getattr(layer, "input_size_per_partition",
None) or layer.input_size
return check_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=layer.input_size,
group_size=group_size)[0]
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
......
......@@ -10,9 +10,15 @@ from vllm.utils import W8a8GetCacheJSON
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
TORCH_DEVICE_IDENTITY = None
W8A8_TRITONJSON=W8a8GetCacheJSON()
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and current_platform.has_device_capability(94))
def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda():
......@@ -108,6 +114,13 @@ def requantize_with_max_scale(
return max_w_scale, weight
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def apply_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
......@@ -174,6 +187,26 @@ def apply_fp8_linear(
return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
elif (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm support rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
......@@ -190,11 +223,6 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
......
......@@ -206,9 +206,10 @@ class RotaryEmbedding(CustomOp):
) -> Tuple[torch.Tensor, torch.Tensor]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
positions = positions.flatten()
if offsets is not None:
offsets = offsets.view(positions.shape[0], -1)
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions).view(
num_tokens, 1, -1)
......@@ -509,15 +510,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
):
super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.rotary_dim = rotary_dim
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
......@@ -557,7 +555,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
return inv_freq
def _compute_cos_sin_cache(
......@@ -596,8 +594,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)
return query.flatten(-2), key.flatten(-2)
......
......@@ -68,7 +68,6 @@ class SampleResultArgsType:
sample_results_dict: SampleResultsDictType
sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor]
# Union of non-deferred (single-step scheduling)
......@@ -523,74 +522,6 @@ def _random_sample(
return results
def _beam_search_sample(
selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor,
) -> SampleResultType:
"""Run beam sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
#
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
is_prompt = seq_group.is_prompt
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.n
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
if is_prompt:
# Prompt phase.
assert num_parent_seqs == 1, (
"Prompt input should have only one seq.")
parent_ids = [0] * (2 * beam_width)
_, next_token_ids = torch.topk(seq_group_logprobs[0],
2 * beam_width)
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
cumulative_logprobs: List[float] = [
seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
cumulative_logprobs_tensor = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs_tensor.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
vocab_size = seq_group_logprobs.size(-1)
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
return results
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
......@@ -679,14 +610,12 @@ def get_pythonized_sample_results(
sampling_metadata,
greedy_samples,
multinomial_samples,
beam_search_logprobs,
sample_results_dict,
) = (
sample_result_args.sample_metadata,
sample_result_args.sampling_metadata,
sample_result_args.greedy_samples,
sample_result_args.multinomial_samples,
sample_result_args.beam_search_logprobs,
sample_result_args.sample_results_dict,
)
......@@ -699,9 +628,6 @@ def get_pythonized_sample_results(
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
return [
......@@ -744,7 +670,6 @@ def _sample_with_torch(
sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {}
greedy_samples: Optional[torch.Tensor] = None
beam_search_logprobs: Optional[torch.Tensor] = None
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
......@@ -813,8 +738,6 @@ def _sample_with_torch(
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
......@@ -825,7 +748,6 @@ def _sample_with_torch(
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output:
......@@ -971,7 +893,9 @@ def get_logprobs(
if len(query_indices) == 0:
empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]
num_seq_groups = len(sampling_metadata.seq_groups)
return [empty_prompt_logprob
] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups
selected_logprobs, ranks = None, None
top_logprobs, top_token_ids = None, None
......@@ -1239,6 +1163,10 @@ def _build_sampler_output(
assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
assert len(sampling_metadata.seq_groups) \
== len(maybe_deferred_sample_results) \
== len(prompt_logprobs) \
== len(sample_logprobs)
deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs,
......
......@@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
......@@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
repetition_penalties, 1.0)[logits <= 0]
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
......@@ -241,24 +241,24 @@ class VocabParallelEmbedding(torch.nn.Module):
self.tp_size)
self.embedding_dim = embedding_dim
linear_method = None
quant_method = None
if quant_config is not None:
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedEmbeddingMethod()
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
linear_method_implements_embedding = method_has_implemented_embedding(
type(linear_method))
if is_embedding_layer and not linear_method_implements_embedding:
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(linear_method).__name__} must implement "
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.linear_method: QuantizeMethodBase = linear_method
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
......@@ -275,13 +275,13 @@ class VocabParallelEmbedding(torch.nn.Module):
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.linear_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
......@@ -427,8 +427,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.linear_method.embedding(self,
masked_input.long())
output_parallel = self.quant_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
......
......@@ -155,6 +155,30 @@ def _initialize_model(
return model_class(**kwargs)
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None:
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Currently only used by MLA.
# NOTE: This intentionally happens after other modules so we can easily
# decompress the weights for MLA.
for _, module in model.named_modules():
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
......@@ -378,7 +402,6 @@ class DefaultModelLoader(BaseModelLoader):
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
......@@ -396,23 +419,8 @@ class DefaultModelLoader(BaseModelLoader):
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase) and quant_method != "awq" and quant_method != "gptq" and quant_method != "compressed_tensors":
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()
......@@ -431,29 +439,15 @@ class DummyModelLoader(BaseModelLoader):
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()
......@@ -634,6 +628,7 @@ class ShardedStateLoader(BaseModelLoader):
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
......@@ -642,18 +637,10 @@ class ShardedStateLoader(BaseModelLoader):
model_config.revision)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(
model_config.dtype)
_process_weights_after_loading(model, model_config,
target_device)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
......@@ -1342,6 +1329,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
......@@ -1355,7 +1343,6 @@ class RunaiModelStreamerLoader(BaseModelLoader):
revision,
ignore_patterns=self.load_config.ignore_patterns,
))
if is_s3_path:
hf_weights_files = s3_glob(path=hf_folder,
allow_pattern=[safetensors_pattern])
......@@ -1403,16 +1390,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
self._get_weights_iterator(model_weights,
model_config.revision))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()
......
......@@ -48,22 +48,31 @@ def resolve_transformers_fallback(model_config: ModelConfig,
for i, arch in enumerate(architectures):
if arch == "TransformersModel":
continue
custom_module = None
auto_map = getattr(model_config.hf_config, "auto_map", None)
if auto_map is not None and "AutoModel" in auto_map:
custom_module = get_class_from_dynamic_module(
model_config.hf_config.auto_map["AutoModel"],
model_config.model)
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
# Make sure that config class is always initialized before model class,
# otherwise the model class won't be able to access the config class,
# the expected auto_map should have correct order like:
# "auto_map": {
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(module, model_config.model)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
custom_model_module = auto_modules.get("AutoModel")
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_module):
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersModel"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_module):
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM.")
......
......@@ -6,6 +6,7 @@ import hashlib
import json
import os
import tempfile
import time
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
......@@ -237,7 +238,8 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
if not huggingface_hub.constants.HF_HUB_OFFLINE:
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
if not local_only:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
......@@ -253,6 +255,7 @@ def download_weights_from_hf(
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
start_time = time.perf_counter()
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
......@@ -260,8 +263,12 @@ def download_weights_from_hf(
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
local_files_only=local_only,
)
time_taken = time.perf_counter() - start_time
if time_taken > 0.5:
logger.info("Time spent downloading weights for %s: %.6f seconds",
model_name_or_path, time_taken)
return hf_folder
......@@ -453,7 +460,6 @@ def pt_weights_iterator(
state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items()
del state
torch.cuda.empty_cache()
def get_gguf_extra_tensor_names(
......
......@@ -33,7 +33,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP
from .interfaces import SupportsPP, SupportsQuant
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -423,7 +423,8 @@ class ArcticModel(nn.Module):
return hidden_states
class ArcticForCausalLM(nn.Module, SupportsPP):
class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -36,7 +36,7 @@ from .idefics2_vision_model import Idefics2VisionConfig
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix,
......@@ -53,7 +53,8 @@ class AriaImagePixelInputs(TypedDict):
"""
class AriaVisionTransformer(Idefics3VisionTransformer):
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(
self,
......@@ -304,11 +305,17 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
class AriaTextModel(LlamaModel):
class AriaTextModel(LlamaModel, SupportsQuant):
"""
Custom LlamaModel for the AriaMoE model which modifies the standard
LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
"""
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"experts.w13_weight": ["experts.fc1.weight"],
"experts.w2_weight": ["experts.fc2.weight"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
......@@ -393,8 +400,8 @@ class AriaProcessingInfo(BaseProcessingInfo):
def get_vision_config(self):
return self.get_hf_config().vision_config
def get_hf_processor(self):
return self.ctx.get_hf_processor(AriaProcessor)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(AriaProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
......
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Bamba model."""
# Added by the IBM Team, 2024
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import BambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class BambaMLP(nn.Module):
def __init__(
self,
config: BambaConfig,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
output_sizes=[config.intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
)
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class BambaMixerDecoderLayer(nn.Module):
def __init__(self,
config: BambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.mamba = MambaMixer2(hidden_size= config.hidden_size,
ssm_state_size = config.mamba_d_state,
conv_kernel_size = config.mamba_d_conv,
intermediate_size = config.mamba_expand *\
config.hidden_size,
use_conv_bias = config.mamba_conv_bias,
use_bias = config.mamba_proj_bias,
n_groups=config.mamba_n_groups,
num_heads=config.mamba_n_heads,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
chunk_size=config.mamba_chunk_size,
quant_config=quant_config)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata,
mamba_cache_params, sequence_idx)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
class BambaAttentionDecoderLayer(nn.Module):
def __init__(
self,
config: BambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "partial_rotary_factor"):
rotary_dim = self.head_dim * config.partial_rotary_factor
elif hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
rotary_dim = self.head_dim # default
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
rope_scaling=rope_scaling,
base=rope_theta,
is_neox_style=True,
dtype=torch.get_default_dtype(), # see impl of get_rope
)
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def self_attention(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
"attention": BambaAttentionDecoderLayer,
"mamba": BambaMixerDecoderLayer
}
class BambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.layers_block_type[layer_idx]]
return layer_class(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
residual = None
num_attn = 0
for i in range(len(self.layers)):
layer = self.layers[i]
kv_cache = None
if isinstance(layer, BambaAttentionDecoderLayer):
kv_cache = kv_caches[num_attn]
num_attn += 1
layer_mamba_cache_params = None
if isinstance(layer, BambaMixerDecoderLayer):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
sequence_idx=seq_idx,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Bamba currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = BambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape, temporal_state_shape = None, None
intermediate_size = self.config.mamba_expand * hidden_size
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
self.config.mamba_n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.mamba_d_state)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.mamba_n_heads, world_size),
self.config.mamba_d_head,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
if ".self_attn." in name:
name = name.replace(".self_attn", "")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -58,8 +58,8 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)
def get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict)
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from PIL import Image
from torch import nn
from torch.nn import LayerNorm
import os
......@@ -18,9 +13,6 @@ import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -33,21 +25,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -56,185 +38,6 @@ from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
logger = init_logger(__name__)
def calculate_image_placeholder(vision_config):
return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
def mm_input_mapper_for_glmv(
ctx: InputContext,
data: ModalityData[object],
) -> Dict:
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
if tokenizer is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": data
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True).data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
pixel_values = raw_batch_data['images']
return MultiModalKwargs({'pixel_values': pixel_values})
def merge_glm_vision_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
boi_token_id: int,
eoi_token_id: int,
) -> torch.Tensor:
boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
mask = torch.zeros_like(input_ids, dtype=torch.bool)
for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
assert boi_pos < eoi_pos
mask[boi_pos:eoi_pos + 1] = True
inputs_embeds[mask] = vision_embeddings.view(-1,
vision_embeddings.shape[-1])
return inputs_embeds
class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
def get_max_glmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return 1
elif isinstance(vision_config, dict):
return calculate_image_placeholder(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]) -> DummyData:
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
seq_data = SequenceData(token_ids)
return DummyData(seq_data, None)
elif isinstance(vision_config, dict):
image_size = vision_config["image_size"]
image_placeholder_length = calculate_image_placeholder(vision_config)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
[0] * image_placeholder_length +
[hf_config.eoi_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0] * (seq_len - image_placeholder_length - 2))
seq_data = SequenceData(token_ids)
mm_data = {
"image": Image.new("RGB", (image_size, image_size), color=0)
}
return DummyData(seq_data, mm_data)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def find_all_positions(input_ids: List[int], target: int) -> List[int]:
return [index for index, value in enumerate(input_ids) if value == target]
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return inputs
elif isinstance(vision_config, dict):
image_placeholder_length = calculate_image_placeholder(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
input_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(
ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code)
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": multi_modal_data["image"],
"content": inputs['prompt'],
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
).data
except Exception:
logger.error("Failed to process content (%s)", inputs['prompt'])
raise
input_ids = raw_batch_data['input_ids'][0].tolist()
boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id
boi_positions = find_all_positions(input_ids, boi_token_id)
eoi_positions = find_all_positions(input_ids, eoi_token_id)
assert len(boi_positions) == len(eoi_positions)
new_input_ids = []
final_processed_position = 0
for boi_position, eoi_position in zip(boi_positions, eoi_positions):
assert boi_position < eoi_position
new_input_ids.extend(input_ids[final_processed_position:boi_position +
1])
new_input_ids.extend([input_ids[boi_position + 1]] *
image_placeholder_length)
final_processed_position = eoi_position
new_input_ids.extend(input_ids[final_processed_position:])
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(new_input_ids)
return token_inputs(
prompt_token_ids=new_input_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
)
class GLMAttention(nn.Module):
def __init__(
......@@ -500,7 +303,7 @@ class GLMTransformer(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
......@@ -509,8 +312,12 @@ class GLMTransformer(nn.Module):
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
# Final layer norm.
if get_pp_group().is_last_rank and self.post_layer_norm:
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
......@@ -545,15 +352,6 @@ class ChatGLMModel(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.output_layer")
vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config,
quant_config,
prefix=f"{prefix}.vision")
else:
self.vision = None
self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors)
......@@ -566,45 +364,8 @@ class ChatGLMModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None and self.vision is not None:
if isinstance(pixel_values, torch.Tensor):
if pixel_values.ndim > 2:
pixel_values = torch.concat(list(pixel_values))
elif isinstance(pixel_values, list):
return torch.concat(pixel_values)
else:
raise TypeError("""pixel_values must be a torch.Tensor
or a list of torch.Tensor
""")
return GLMImagePixelInputs(pixel_values=pixel_values)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is None:
return None
pixel_values = image_input["pixel_values"].to(
dtype=self.config.torch_dtype)
vision_embeddings = self.vision(pixel_values)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=multimodal_embeddings,
boi_token_id=self.config.boi_token_id,
eoi_token_id=self.config.eoi_token_id)
return inputs_embeds
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
def forward(
self,
......@@ -615,28 +376,24 @@ class ChatGLMModel(nn.Module):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> torch.Tensor:
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
if intermediate_tensors is None and inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else:
inputs_embeds = intermediate_tensors["hidden_states"]
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
# Run encoder.
hidden_states = self.encoder(
hidden_states=inputs_embeds,
hidden_states=hidden_states,
position_ids=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
......@@ -722,12 +479,18 @@ class ChatGLMModel(nn.Module):
return loaded_params
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
class ChatGLMBaseModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".word_embeddings": ""}, )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
transformer_type: type[ChatGLMModel] = ChatGLMModel,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -740,9 +503,9 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.transformer = transformer_type(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
......@@ -750,18 +513,8 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = get_sampler()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
**kwargs)
return hidden_states
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def compute_logits(
self,
......@@ -785,7 +538,7 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
class ChatGLM(ChatGLMBaseModel):
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
......@@ -801,72 +554,28 @@ class ChatGLM(ChatGLMBaseModel):
embedding_modules = {}
embedding_padding_modules = []
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
"merged_proj": ["gate_proj", "dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
# vision
"fc1",
"fc2",
"merged_proj",
"linear_proj"
]
embedding_modules = {}
embedding_padding_modules = []
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="transformer.encoder",
connector="transformer.vision.linear_proj",
tower_model="transformer.vision.transformer")
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
embedding_padding_modules = []
def __new__(
cls,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if hasattr(config, "vision_config"):
hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
raise RuntimeError(
"The configuration of this model indicates that it supports "
"vision inputs, but you instantiated the text-only version "
"of this model. Please use the vision model by setting "
f"`--hf-overrides {hf_overrides!r}`")
# Initialize VL
if hasattr(config, "vision_config"): # noqa: SIM108
instance_cls = ChatGLMV
# Initialize LLM
else:
instance_cls = ChatGLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
import vllm.envs as envs
from vllm.model_executor.models.interfaces import SupportsQuant
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return image_size // patch_size
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_clip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size) + 1
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config)
def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image"):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_clip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
def get_num_image_tokens(
......@@ -160,10 +28,10 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
image_width: int,
image_height: int,
) -> int:
return get_clip_image_feature_size(self.vision_config)
return self.get_patch_grid_length()**2 + 1
def get_max_image_tokens(self) -> int:
return get_max_clip_image_tokens(self.vision_config)
return self.get_patch_grid_length()**2 + 1
def get_image_size(self) -> int:
return self.vision_config.image_size
......@@ -172,10 +40,9 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_clip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
image_size, patch_size = self.get_image_size(), self.get_patch_size()
assert image_size % patch_size == 0
return image_size // patch_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
......@@ -187,6 +54,7 @@ class CLIPVisionEmbeddings(nn.Module):
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
assert self.image_size % self.patch_size == 0
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
......@@ -198,8 +66,7 @@ class CLIPVisionEmbeddings(nn.Module):
bias=False,
)
self.num_patches = get_clip_num_patches(image_size=self.image_size,
patch_size=self.patch_size)
self.num_patches = (self.image_size // self.patch_size)**2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
......@@ -384,7 +251,7 @@ class CLIPEncoder(nn.Module):
def forward(
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = []
hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
......@@ -469,10 +336,10 @@ class CLIPVisionTransformer(nn.Module):
return encoder_outputs
class CLIPVisionModel(nn.Module):
class CLIPVisionModel(nn.Module, SupportsQuant):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(
self,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from .utils import maybe_prefix
class SharedHead(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(hidden_states)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return self.shared_head(hidden_states)
class DeepSeekMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
DeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids,
positions,
kv_caches[spec_step_idx],
attn_metadata,
previous_hidden_states,
inputs_embeds,
spec_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> torch.Tensor:
mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(mtp_layer.shared_head.head,
hidden_states, sampling_metadata)
return logits
class DeepSeekMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, sampling_metadata,
spec_step_idx)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
"""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
spec_layer_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
return name
......@@ -262,9 +262,7 @@ class DeepseekV2Attention(nn.Module):
prefix=f"{prefix}.o_proj")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
......@@ -314,17 +312,8 @@ class DeepseekV2Attention(nn.Module):
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
......@@ -599,7 +588,8 @@ class DeepseekV2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
......@@ -773,13 +763,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if "rotary_emb.inv_freq" in name:
continue
# TODO(simon): support nextn predict layers
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
......@@ -927,4 +913,16 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
\ No newline at end of file
pass
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
weight_name: str) -> Optional[int]:
if hasattr(config,
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
> 0):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment