Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -255,7 +254,6 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.lm_head = self.lm_head.tie_weights(self.transformer.wte)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -282,14 +280,6 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
......
......@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -43,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
......@@ -244,6 +243,30 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
weight_loader(param, loaded_weight, 'q')
weight_loader(param, loaded_weight, 'k')
weight_loader(param, loaded_weight, 'v')
else:
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}
......@@ -278,7 +301,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -305,36 +327,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
weight_loader(param, loaded_weight, 'q')
weight_loader(param, loaded_weight, 'k')
weight_loader(param, loaded_weight, 'v')
else:
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]),
)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -43,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -188,6 +187,7 @@ class GPTJModel(nn.Module):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(
config.vocab_size,
......@@ -228,61 +228,6 @@ class GPTJModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
bias=True,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -339,3 +284,54 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
bias=True,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -356,7 +355,6 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
......@@ -383,14 +381,6 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -441,8 +440,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -464,14 +461,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite speeech model."""
import math
from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, embed_multimodal,
init_vllm_registered_model, maybe_prefix)
### Audio Input
class GraniteSpeechAudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape: `(bsz, num_features, 160)`"""
input_features_mask: torch.Tensor
"""Shape: `(bsz, num_features)`"""
audio_embed_sizes: list[int]
"""List of length `bsz`"""
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
# There is no limit to the maximum number of audio tokens that can be
# encoded as features; we pick ~5000 as a number that is probably higher
# than we would expect to encounter. The sequence of length
# get_max_audio_len() produces get_max_audio_tokens().
def get_max_audio_tokens(self):
return 5001
def get_max_audio_len(self):
return 8000000
### Input Processing & Multimodal utils
class GraniteSpeechMultiModalProcessor(
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_hf_processor().audio_processor
sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
return MultiModalDataParser(target_sr=sampling_rate)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
input_features=MultiModalFieldConfig.batched("audio"),
audio_embed_sizes=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
feature_extractor = processor.audio_processor
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|audio|>")
audio_token_id = vocab[audio_token]
def get_replacement(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
audio = audios.get(item_idx)
audio_length = audio.shape[-1]
num_projector_features = feature_extractor._get_num_audio_features(
[audio_length])[0]
return [audio_token_id] * num_projector_features
return [
PromptReplacement(
modality="audio",
target=[audio_token_id],
replacement=get_replacement,
)
]
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
if audios:
# GraniteSpeechFeatureExtractor accepts "audio"
mm_data["audio"] = audios
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
if "audio" in mm_data:
# Calculate the number of audio tokens per entry in the batch;
# This is used to split the batch back out after padding.
audio_token_index = self.info.get_hf_config().audio_token_index
processed_outputs["audio_embed_sizes"] = [
torch.sum(indices == audio_token_index).item()
for indices in processed_outputs["input_ids"]
]
return processed_outputs
class GraniteSpeechDummyInputsBuilder(
BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
return {
"audio":
self._get_dummy_audios(
length=self.info.get_max_audio_len(),
num_audios=num_audios,
)
}
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
return audio_token * num_audios
### QFormer Projector
class GraniteSpeechEncoderProjector(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: CacheConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.projector_config.hidden_size
self.downsample_rate = config.downsample_rate
self.window_size = config.window_size
self.num_queries = config.window_size // config.downsample_rate
self.query = nn.Parameter(
torch.zeros(1, self.num_queries,
config.projector_config.hidden_size))
# NOTE - this is implemented generically in transformers,
# but for now we create the QFormer model directly since
# all existing models use this for the projector.
self.qformer = Blip2QFormerModel(
config.projector_config,
quant_config=quant_config,
cache_config=cache_config,
prefix=f"{prefix}.qformer",
)
self.linear = nn.Linear(config.projector_config.hidden_size,
config.text_config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, dim = hidden_states.size()
nblocks = math.ceil(seq_len / self.window_size)
pad = nblocks * self.window_size - seq_len
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
"constant", 0)
hidden_states = hidden_states.view(batch_size * nblocks,
self.window_size, dim)
last_hidden_state = self.qformer(
query_embeds=self.query.data,
encoder_hidden_states=hidden_states,
)
query_proj = self.linear(
last_hidden_state.view(
batch_size,
nblocks * self.window_size // self.downsample_rate,
-1,
))
return query_proj
# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
# NOTE - it would be nice to see if we can align this with other models using
# conformer in vLLM, e.g., phi4mm audio.
class GraniteSpeechConformerFeedForward(nn.Module):
"""Feedforward module for conformer encoder blocks."""
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.pre_norm = nn.LayerNorm(config.hidden_dim)
self.up_proj = ColumnParallelLinear(
input_size=config.hidden_dim,
output_size=config.hidden_dim * config.feedforward_mult,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.silu = nn.SiLU()
self.down_proj = RowParallelLinear(
input_size=config.hidden_dim * config.feedforward_mult,
output_size=config.hidden_dim,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(hidden_states)
hidden_states, _ = self.up_proj(hidden_states)
hidden_states = self.silu(hidden_states)
hidden_states, _ = self.down_proj(hidden_states)
return hidden_states
class GraniteSpeechConformerAttention(nn.Module):
"""Attention for conformer blocks using Shaw's relative positional
embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
for more details.
"""
def __init__(self, config: PretrainedConfig, prefix: str = ""):
super().__init__()
inner_dim = config.dim_head * config.num_heads
self.max_pos_emb = config.max_pos_emb
self.context_size = config.context_size
self.num_heads = config.num_heads
self.dim_head = config.dim_head
self.scale = self.dim_head**-0.5
self.pre_norm = nn.LayerNorm(config.hidden_dim)
self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, config.hidden_dim)
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
self.dim_head)
if self.context_size <= 0 or self.context_size > self.max_pos_emb:
raise ValueError(
"Context size is either less than 0 or exceeds the max_pos_emb"
)
def forward(self, hidden_states: torch.Tensor,
attention_dists: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(hidden_states)
bsz, num_features, _ = hidden_states.shape
num_blocks = math.ceil(num_features / self.context_size)
remainder = num_features % self.context_size
if remainder > 0:
# right padding to reach block size
hidden_states = torch.nn.functional.pad(
hidden_states, (0, 0, 0, self.context_size - remainder))
# NOTE: would be nice to try to use qkvparallellinear
# here for this block attention implementation if possible
query_states = self.to_q(hidden_states)
key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
query_states = query_states.reshape(bsz, num_blocks, self.context_size,
self.num_heads,
-1).transpose(2, 3)
key_states = key_states.reshape(bsz, num_blocks, self.context_size,
self.num_heads, -1).transpose(2, 3)
value_states = value_states.reshape(bsz, num_blocks, self.context_size,
self.num_heads,
-1).transpose(2, 3)
# shaw's relative positional embedding
dist = attention_dists.to(hidden_states.device)
rel_pos_emb = self.rel_pos_emb(dist)
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
list(rel_pos_emb.shape))
pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
dim=-1) * self.scale
if remainder > 0:
# masked attention in the extended block
mask = torch.ones(self.context_size,
self.context_size,
dtype=bool,
device=hidden_states.device)
mask[:remainder, :remainder] = 0
mask_value = -torch.finfo(pos_attn.dtype).max
pos_attn[:, -1, :].masked_fill_(mask, mask_value)
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.MATH):
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
attn_mask=pos_attn,
scale=self.scale)
out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
return self.to_out(out[:, :num_features, :])
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
"""Wrapper for padded 1D pointwise convolution."""
def __init__(self,
chan_in: int,
chan_out: int,
kernel_size: int,
prefix: str = ""):
super().__init__()
# Padding for the 1D conv is symmetric or close (i.e., offset by one).
pad = kernel_size // 2
pad_offset = (kernel_size + 1) % 2
self.padding = (pad, pad - pad_offset)
self.conv = nn.Conv1d(chan_in,
chan_out,
kernel_size,
groups=chan_in,
bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.padding)
return self.conv(hidden_states)
class GraniteSpeechConformerConvModule(nn.Module):
"""Conformer conv module consisting of several 1D/depthwise 1D
convolutional layers.
"""
def __init__(self, config: PretrainedConfig, prefix: str = ""):
super().__init__()
inner_dim = config.hidden_dim * config.conv_expansion_factor
self.norm = nn.LayerNorm(config.hidden_dim)
self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
self.glu = nn.GLU(dim=1)
self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
inner_dim,
inner_dim,
kernel_size=config.conv_kernel_size,
prefix=f"{prefix}.depth_conv",
)
self.silu = nn.SiLU()
self.batch_norm = nn.BatchNorm1d(inner_dim)
self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
hidden_states = self.glu(hidden_states)
hidden_states = self.depth_conv(hidden_states)
hidden_states = self.silu(self.batch_norm(hidden_states))
hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
return hidden_states
class GraniteSpeechConformerBlock(nn.Module):
"""Conformer block, consisting largely of linear layers,
attention, and convolutional layers."""
def __init__(self, config: PretrainedConfig, prefix: str = ""):
super().__init__()
self.ff1 = GraniteSpeechConformerFeedForward(config,
prefix=f"{prefix}.ff1")
self.attn = GraniteSpeechConformerAttention(config,
prefix=f"{prefix}.attn")
self.conv = GraniteSpeechConformerConvModule(config,
prefix=f"{prefix}.conv")
self.ff2 = GraniteSpeechConformerFeedForward(config,
prefix=f"{prefix}.ff2")
self.post_norm = nn.LayerNorm(config.hidden_dim)
def forward(self, hidden_states: torch.Tensor,
attention_dists: torch.Tensor) -> torch.Tensor:
hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
hidden_states = self.attn(
hidden_states, attention_dists=attention_dists) + hidden_states
hidden_states = self.conv(hidden_states) + hidden_states
hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
hidden_states = self.post_norm(hidden_states)
return hidden_states
class GraniteSpeechCTCEncoder(nn.Module):
"""CTC Encoder comprising conformer blocks and additional linear layers."""
def __init__(self,
config: PretrainedConfig,
prefix: str,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
# Precompute clamped relative positional encoding distances
seq = torch.arange(config.context_size)
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
self.attention_dists = torch.clamp(
relpos_dist, -config.context_size,
config.context_size) + config.max_pos_emb
self.input_linear = nn.Linear(config.input_dim,
config.hidden_dim,
bias=True)
self.layers = nn.ModuleList([
GraniteSpeechConformerBlock(
config,
prefix=f"{prefix}.layers.{idx}",
) for idx in range(config.num_layers)
])
self.out = ColumnParallelLinear(
input_size=config.hidden_dim,
output_size=config.output_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out",
)
self.out_mid = RowParallelLinear(
input_size=config.output_dim,
output_size=config.hidden_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_mid",
)
self.softmax = nn.Softmax(dim=-1)
self.num_layers = config.num_layers
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.input_linear(hidden_states)
for idx, layer in enumerate(self.layers, start=1):
hidden_states = layer(hidden_states,
attention_dists=self.attention_dists)
if idx == self.num_layers // 2:
hidden_states_mid = hidden_states.clone()
hidden_states_mid, _ = self.out(hidden_states_mid)
hidden_states_mid = self.softmax(hidden_states_mid)
hidden_states_mid, _ = self.out_mid(hidden_states_mid)
hidden_states += hidden_states_mid
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(
GraniteSpeechMultiModalProcessor,
info=GraniteSpeechMultiModalProcessingInfo,
dummy_inputs=GraniteSpeechDummyInputsBuilder)
class GraniteSpeechForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsLoRA,
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
self.config = config
self.quant_config = quant_config
self.cache_config = cache_config
self.sampler = get_sampler()
# The language model is typically a Granite LLM
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
# Conformer encoder
self.encoder = GraniteSpeechCTCEncoder(
config=config.encoder_config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
)
# Blip2 QFormer
self.projector = GraniteSpeechEncoderProjector(
config=config,
quant_config=quant_config,
cache_config=cache_config,
prefix=f"{prefix}.projector",
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _parse_and_validate_audio_input(
self,
**kwargs: object,
) -> Optional[GraniteSpeechAudioInputs]:
input_features = kwargs.pop("input_features", None)
input_features_mask = kwargs.pop("input_features_mask", None)
audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
if input_features is None:
return None
# If we have a batch of variable feature length audio clips, we need
# to mask the features; usually we would get an input_features_mask
# from the processor, but we handle rebuilding it here since
# vLLM generally processes everything independently + batches.
if input_features_mask is None:
input_features_mask = self._build_input_features_mask(
audio_embed_sizes)
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_features)}")
if input_features_mask is not None and not isinstance(
input_features_mask, torch.Tensor):
raise ValueError("Incorrect type of audio input features mask. "
f"Got type: {type(input_features_mask)}")
if isinstance(input_features, torch.Tensor):
# Granite speech currently only allows one audio token per instance
# and features are already unsqueezed in the processor, so one
# instance will have shape [1, {num_features}, 160]. As such,
# input features will usually be of shape
# [bsz, 1, num_features, 160], which we squeeze to be 3D here.
if len(input_features.shape) == 4:
input_features = input_features.squeeze(1)
if len(input_features.shape) != 3:
raise ValueError(
"Squeezed input features should be 3D but are of shape "
f"{input_features.shape}")
input_features = input_features.to(
self.encoder.input_linear.weight.dtype)
else:
# Otherwise we have a list of tensors, which are almost certainly
# differing in their respective numbers of audio features;
# stack them into a 3D tensor of size [bsz, most_num_features, 160].
input_features = self._pad_and_stack_input_features(
input_features, ).to(self.encoder.input_linear.weight.dtype)
return GraniteSpeechAudioInputs(
input_features=input_features,
input_features_mask=input_features_mask,
audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
)
def _build_input_features_mask(
self,
audio_embed_sizes: torch.Tensor,
) -> torch.Tensor:
"""Calculate the input features mask, which will generally be used
to mask the the padded features for all entries in the batch except
for those with the most audio features.
Args:
audio_embed_sizes: torch.Tensor
Tensor of num features in each seq in the batch.
Returns:
torch.Tensor: Mask of shape (bsz, num_features) to be applied to
the audio features prior to splitting the audio embeddings.
"""
most_audio_features = torch.max(audio_embed_sizes).item()
mask_indices = torch.arange(
most_audio_features,
device=audio_embed_sizes.device,
).view(1, -1)
input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
return input_features_mask
def _pad_and_stack_input_features(
self,
input_features: list[torch.Tensor],
) -> torch.Tensor:
"""Given a list of input features of varying length, pad them to the
same length and stack them into a torch.Tensor.
NOTE: Usually, padding is done in the input processor/feature extractor
and zero padded prior to the computation of the Mel features; the
resulting values are only constant within a batch and generally nonzero
(i.e., slightly negative nums); we should validate that this is okay
since we don't use a feature attention mask, but the more important
thing is that we apply the input_features_mask with variable len
batches.
Args:
input_features: list[torch.Tensor]
Input features to be coerced into a tensor.
Returns:
torch.Tensor: Tensor of shape [bsz, num_features, 160], where
num_features is the max number of features of any entry in the
batch.
"""
# Input features are of shape [bsz, num_features, 160]
feat_lens = [feats.shape[1] for feats in input_features]
padding = [max(feat_lens) - length for length in feat_lens]
# TODO (Alex) - Validate that it's okay to zero pad like this;
# in transformers we zero pad prior to calculating the speech features,
# so the value is not zero and is dependent on the batched features.
padded = [
torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
for feats, pad in zip(input_features, padding)
]
stacked_features = torch.cat(padded, dim=0).to(input_features[0])
return stacked_features
def _process_audio_input(
self,
audio_input: GraniteSpeechAudioInputs,
) -> tuple[torch.Tensor]:
"""Compute the audio features to be merged into the LLM embeddings.
Args:
audio_input: GraniteSpeechAudioInputs
Audio inputs object containing Mel features, an input features
mask, and the (flattened) number of audio tokens per instance.
Returns:
tuple[torch.Tensor]: List of length bsz.
"""
# TODO (Alex) - support embedding inputs
encoder_embeds = self.encoder(audio_input["input_features"])
# [bsz, <max feature size>, 4096]
projected_embeds = self.projector(encoder_embeds)
# Apply mask on variable length audio features
masked_embeds = projected_embeds[audio_input["input_features_mask"]]
# Split variable length features into a tuple
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
def get_multimodal_embeddings(
self,
**kwargs: object,
) -> Optional[MultiModalEmbeddings]:
"""Compute the audio embeddings if audio inputs are present."""
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None
audio_features = self._process_audio_input(audio_input)
return audio_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
"""Compute the merged LLM / audio embeddings."""
if multimodal_embeddings is None:
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
input_ids,
self.config.audio_token_index,
self.language_model.model.get_input_embeddings,
multimodal_embeddings,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
audio_embeds = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
input_ids = None
model_output = self.language_model(input_ids, positions,
intermediate_tensors, inputs_embeds)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(
hidden_states,
sampling_metadata,
)
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""Get the module prefix in multimodal models."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="projector",
tower_model="encoder",
)
......@@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -391,8 +390,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 /
self.config.logits_scaling)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -428,14 +425,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -20,7 +20,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -295,8 +294,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 /
self.config.logits_scaling)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -332,14 +329,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = self.head(mean_embeddings)
pooled_data = self.head(mean_embeddings,
pooling_metadata=pooling_metadata)
pooled_outputs = [
PoolingSequenceGroupOutput(data) for data in pooled_data
......
......@@ -39,7 +39,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -521,7 +520,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config.vocab_size,
self.output_multiplier_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -551,14 +549,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = ["rotary_emb.inv_freq"]
......
......@@ -28,7 +28,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -603,7 +602,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if self.config.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.text_model.wte.weight
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
......@@ -754,14 +752,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -13,7 +13,6 @@ from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -103,14 +102,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
"""Return `None` if TP rank > 0."""
...
def sample(
self,
logits: T,
sampling_metadata: "SamplingMetadata",
) -> "SamplerOutput":
"""Only called on TP rank 0."""
...
@overload
def is_text_generation_model(
......
......@@ -23,7 +23,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -336,7 +335,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -363,14 +361,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -423,7 +413,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
prefix=prefix,
model_type=model_type)
for attr in ("output", "logits_processor", "sampler"):
for attr in ("output", "logits_processor"):
delattr(self, attr)
config = vllm_config.model_config.hf_config
......
......@@ -8,7 +8,6 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch
......@@ -20,7 +19,6 @@ from transformers import BatchEncoding, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -698,13 +696,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
(llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _init_vision_model(
self,
config: PretrainedConfig,
......@@ -903,7 +894,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
......@@ -941,13 +932,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
......
......@@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -308,7 +307,6 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
config.mup_width_scale)
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
......@@ -335,14 +333,6 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
......
......@@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -409,7 +408,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -466,14 +464,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math
from collections.abc import Mapping
from dataclasses import dataclass
from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple,
TypedDict, Union)
import torch
from torch import nn
from transformers import BatchFeature
from transformers.activations import GELUActivation
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
from .utils import is_pp_missing_parameter, maybe_prefix
# For dummy input only
@dataclass
class MaxImageTokenMeta:
width: int = 1024
height: int = 1024
class KimiVLMultiModalProjector(nn.Module):
def __init__(self, config: KimiVLConfig):
super().__init__()
self.hidden_size = (config.vision_config.hidden_size *
config.vision_config.merge_kernel_size[0] *
config.vision_config.merge_kernel_size[1])
self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
eps=1e-5)
self.linear_1 = nn.Linear(self.hidden_size,
self.hidden_size,
bias=True)
self.act = GELUActivation()
self.linear_2 = nn.Linear(self.hidden_size,
config.text_config.hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(image_features).view(
-1, self.hidden_size)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class KimiVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape:`(num_patches, num_channels, patch_size, patch_size)`
"""
image_grid_hws: torch.Tensor
"""Shape:`(num_images, 2)`"""
# TODO: support embeds too
# We only support pixel input for kimi-vl now
KimiVLImageInputs = KimiVLImagePixelInputs
class KimiVLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(KimiVLConfig)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_processor = self.get_hf_processor()
patch_size = hf_processor.image_processor.patch_size
kernel_size = hf_processor.image_processor.merge_kernel_size
in_token_limit = hf_processor.image_processor.in_token_limit
height = image_height
width = image_width
assert isinstance(height,
int), f"height must be int, current height {height}"
assert isinstance(width,
int), f"width must be int, current width {width}"
assert kernel_size is not None, "kernel_size must be specified"
if (width // patch_size) * (height // patch_size) > in_token_limit:
scale = math.sqrt(in_token_limit / ((width // patch_size) *
(height // patch_size)))
new_w, new_h = int(width * scale), int(height * scale)
width, height = new_w, new_h
kernel_height, kernel_width = kernel_size
pad_height = (kernel_height * patch_size - height %
(kernel_height * patch_size)) % (kernel_height *
patch_size)
pad_width = (kernel_width * patch_size - width %
(kernel_width * patch_size)) % (kernel_width * patch_size)
# Calculate new dimensions after padding and patching
token_height = (height + pad_height) // (kernel_size[0] * patch_size)
token_width = (width + pad_width) // (kernel_size[1] * patch_size)
return int(token_height * token_width)
@property
def image_token_id(self) -> int:
return self.get_hf_config().media_placeholder_token_id
class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
return {
"image":
self._get_dummy_images(width=MaxImageTokenMeta.width,
height=MaxImageTokenMeta.height,
num_images=num_images)
}
class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
image_grid_sizes = image_grid_hws.prod(-1)
# pixel_values is merged as a single large tensor
# image_grid_hws is shapes for each subtensor in pixel_values
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_hws=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
image_token_id = self.info.image_token_id
def get_replacement(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement,
),
]
@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
info=KimiVLProcessingInfo,
dummy_inputs=KimiVLDummyInputsBuilder)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
model_config = vllm_config.model_config
config: KimiVLConfig = model_config.hf_config
self.config = config
quant_config = vllm_config.quant_config
assert isinstance(config.vision_config, MoonViTConfig)
self.vision_tower = MoonVitPretrainedModel(config.vision_config)
self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
self.quant_config = quant_config
sub_vllm_config = copy.deepcopy(vllm_config)
sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config
self.language_model = DeepseekV2Model(
vllm_config=sub_vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.config.text_config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.media_placeholder: int = self.config.media_placeholder_token_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_world_size = get_tensor_model_parallel_world_size()
# ref: qwen2_vl.py
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return mm_input.reshape(-1, mm_input.shape[-1])
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KimiVLImageInputs]:
# image input type must be pixel values now
pixel_values = kwargs.pop("pixel_values", None)
image_grid_hws = kwargs.pop("image_grid_hws", None)
if pixel_values is None:
return None
image_grid_hws = self._validate_and_reshape_mm_tensor(
image_grid_hws, "image grid hws")
# pixel_values may have complex shapes
num_channels = 3
patch_size = self.config.vision_config.patch_size
if isinstance(pixel_values, list):
pixel_values = torch.cat([
x.reshape(-1, num_channels, patch_size, patch_size)
for x in pixel_values
])
else:
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
patch_size)
pixel_values = pixel_values.to(self.vision_tower.dtype)
# image_grid_hws.shape = (N, 2)
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
return KimiVLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_hws=image_grid_hws,
)
# perform vt on processored pixel_values
@torch.inference_mode()
def _process_image_pixels(self,
inputs: KimiVLImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
image_grid_hws = inputs["image_grid_hws"]
return self.vision_tower(pixel_values, image_grid_hws)
def _process_image_input(self,
image_input: KimiVLImageInputs) -> torch.Tensor:
assert image_input["type"] == "pixel_values"
image_features = self._process_image_pixels(image_input)
assert isinstance(image_features, list)
lengths = [x.shape[0] for x in image_features]
return self.multi_modal_projector(
torch.cat(image_features)).split(lengths)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> Optional[NestedTensors]:
# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
# Run multimodal inputs through encoder and projector
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.media_placeholder_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner from
# `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
inputs_embeds = None
else:
inputs_embeds = self.get_input_embeddings(input_ids)
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.
media_placeholder_token_id,
)
input_ids = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
**kwargs) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, **kwargs)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if not config.use_mla:
stacked_params_mapping += [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
if getattr(config, "n_routed_experts", None):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=config.n_routed_experts)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id, **kwargs)
break
else:
for idx, (param_name, weight_name, expert_id,
shard_id) in enumerate(expert_params_mapping):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
**kwargs)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight, **kwargs)
def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config,
weight_name: str) -> Optional[int]:
if hasattr(config,
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
> 0):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None
......@@ -44,7 +44,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -139,8 +138,8 @@ class LlamaAttention(nn.Module):
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
......@@ -172,11 +171,12 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
if hasattr(config, "interleaved_sliding_window"):
......@@ -346,6 +346,8 @@ class LlamaModel(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int] = tuple()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
......@@ -372,7 +374,8 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
list[torch.Tensor]]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
......@@ -384,7 +387,11 @@ class LlamaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
......@@ -394,6 +401,9 @@ class LlamaModel(nn.Module):
})
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
......@@ -679,11 +689,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
......@@ -715,11 +730,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......
......@@ -51,8 +51,8 @@ class Llama4MoE(nn.Module):
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
# psuedo-standard is that the router scores are floats
router_scores = torch.sigmoid(router_scores.float())
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
......
......@@ -70,7 +70,7 @@ class LlamaModel(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
......@@ -82,7 +82,8 @@ class LlamaModel(nn.Module):
hidden_states,
residual,
)
return hidden_states + residual
hidden_states = hidden_states + residual
return hidden_states, hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
......@@ -132,7 +133,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment