Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
...@@ -20,22 +20,21 @@ from vllm.jsontree import json_map_leaves ...@@ -20,22 +20,21 @@ from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalFieldConfig, MultiModalFieldConfig,
PromptReplacement, PromptUpdate, PromptReplacement, PromptUpdate,
encode_tokens) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class AyaVisionImagePixelInputs(TypedDict): class AyaVisionImagePixelInputs(TypedDict):
...@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict): ...@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class AyaVisionMultiModalProjector(nn.Module): class AyaVisionMultiModalProjector(nn.Module):
...@@ -125,32 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): ...@@ -125,32 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
def get_image_processor(self) -> GotOcr2ImageProcessor: def get_image_processor(self) -> GotOcr2ImageProcessor:
return self.get_hf_processor().image_processor return self.get_hf_processor().image_processor
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
image_processor = hf_processor.image_processor
image_size = self.get_image_size_with_most_features()
tokenizer = hf_processor.tokenizer
num_patches = self.get_num_patches(
image_width=image_size.width,
image_height=image_size.height,
size=image_processor.size,
min_patches=image_processor.min_patches,
max_patches=image_processor.max_patches)
image_string = hf_processor._prompt_split_image(num_patches)
x = encode_tokens(
tokenizer,
image_string,
add_special_tokens=False,
)
return len(x)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
...@@ -180,28 +146,29 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): ...@@ -180,28 +146,29 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
class AyaVisionDummyInputsBuilder( class AyaVisionDummyInputsBuilder(
BaseDummyInputsBuilder[AyaVisionProcessingInfo]): BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self, num_images = mm_counts.get("image", 0)
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
processor = self.info.get_hf_processor() processor = self.info.get_hf_processor()
image_token = processor.image_token image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_size = \ image_size = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=image_size.width, self._get_dummy_images(width=image_size.width,
height=image_size.height, height=image_size.height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class AyaVisionMultiModalProcessor( class AyaVisionMultiModalProcessor(
...@@ -221,7 +188,6 @@ class AyaVisionMultiModalProcessor( ...@@ -221,7 +188,6 @@ class AyaVisionMultiModalProcessor(
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_processor = hf_processor.image_processor image_processor = hf_processor.image_processor
hf_config = self.info.get_hf_config()
# HF processor pops the `num_patches` kwarg, which is needed by vLLM # HF processor pops the `num_patches` kwarg, which is needed by vLLM
if (images := if (images :=
mm_data.get("images")) is not None and '<image>' in prompt: mm_data.get("images")) is not None and '<image>' in prompt:
...@@ -234,6 +200,7 @@ class AyaVisionMultiModalProcessor( ...@@ -234,6 +200,7 @@ class AyaVisionMultiModalProcessor(
parsed_images.get_image_size(i) parsed_images.get_image_size(i)
for i in range(len(parsed_images)) for i in range(len(parsed_images))
] ]
num_patches = [ num_patches = [
self.info.get_num_patches( self.info.get_num_patches(
image_width=image_size.width, image_width=image_size.width,
...@@ -243,20 +210,6 @@ class AyaVisionMultiModalProcessor( ...@@ -243,20 +210,6 @@ class AyaVisionMultiModalProcessor(
max_patches=image_processor.max_patches) max_patches=image_processor.max_patches)
for image_size in image_sizes for image_size in image_sizes
] ]
image_tokens_list = [
hf_processor._prompt_split_image(num_patch)
for num_patch in num_patches
]
tokenizer = self.info.get_tokenizer()
image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in image_tokens_list
]
embed_is_patch = [
torch.tensor(image_repl_tokens) == hf_config.image_token_index
for image_repl_tokens in image_token_ids
]
processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["num_patches"] = torch.tensor(num_patches) processed_outputs["num_patches"] = torch.tensor(num_patches)
return processed_outputs return processed_outputs
...@@ -271,7 +224,6 @@ class AyaVisionMultiModalProcessor( ...@@ -271,7 +224,6 @@ class AyaVisionMultiModalProcessor(
pixel_values=MultiModalFieldConfig.flat_from_sizes( pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches), "image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -283,6 +235,7 @@ class AyaVisionMultiModalProcessor( ...@@ -283,6 +235,7 @@ class AyaVisionMultiModalProcessor(
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token image_token = hf_processor.image_token
img_patch_token = hf_processor.img_patch_token
image_processor = hf_processor.image_processor image_processor = hf_processor.image_processor
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
...@@ -294,8 +247,11 @@ class AyaVisionMultiModalProcessor( ...@@ -294,8 +247,11 @@ class AyaVisionMultiModalProcessor(
image_height=image_size.height, image_height=image_size.height,
size=image_processor.size, size=image_processor.size,
min_patches=image_processor.min_patches, min_patches=image_processor.min_patches,
max_patches=image_processor.max_patches) max_patches=image_processor.max_patches,
return hf_processor._prompt_split_image(num_patches=num_patches) )
repl = hf_processor._prompt_split_image(num_patches=num_patches)
return PromptUpdateDetails.select_text(repl, img_patch_token)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -424,7 +380,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -424,7 +380,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None) num_patches = kwargs.pop("num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Aya Vision does not support image_embeds." assert image_embeds is None, "Aya Vision does not support image_embeds."
...@@ -436,30 +391,25 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -436,30 +391,25 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of num_patches. " raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}") f"Got type: {type(num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
num_patches = flatten_bn(num_patches, concat=True) num_patches = flatten_bn(num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return AyaVisionImagePixelInputs( return AyaVisionImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_patches, num_patches=num_patches,
embed_is_patch=embed_is_patch,
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input, **kwargs)
return scatter_patch_features( return self._process_image_input(image_input, **kwargs)
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -471,9 +421,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -471,9 +421,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multimodal_embeddings=select_patch_features( multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings), placeholder_token_id=self.config.image_token_index,
placeholder_token_id=self.config.image_token_index) )
return inputs_embeds return inputs_embeds
......
...@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim=config.mamba_d_head, head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps, rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act, activation=config.hidden_act,
chunk_size=config.mamba_chunk_size,
quant_config=quant_config) quant_config=quant_config)
self.feed_forward = BambaMLP(config, quant_config=quant_config) self.feed_forward = BambaMLP(config, quant_config=quant_config)
...@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None, mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual) hidden_states, residual)
hidden_states = self.mamba(hidden_states, mamba_cache_params, hidden_states = self.mamba(hidden_states, mamba_cache_params,
sequence_idx) mamba2_metadata)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual) hidden_states, residual)
...@@ -259,7 +260,7 @@ class BambaModel(nn.Module): ...@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config: BambaConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
...@@ -309,20 +310,13 @@ class BambaModel(nn.Module): ...@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) mamba2_metadata = prepare_mamba2_metadata(
for i, (srt, end) in enumerate( chunk_size=self.config.mamba_chunk_size,
zip( input_ids=input_ids,
attn_metadata.query_start_loc, attn_metadata=attn_metadata,
attn_metadata.query_start_loc[1:], )
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
...@@ -352,7 +346,7 @@ class BambaModel(nn.Module): ...@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params, mamba_cache_params=layer_mamba_cache_params,
sequence_idx=seq_idx, mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
PoolingType) PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module): ...@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
self.size = config.hidden_size self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size, self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding( self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size) config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute": if self.position_embedding_type == "absolute":
raise ValueError("Only 'absolute' position_embedding_type" + self.position_embeddings = VocabParallelEmbedding(
" is supported") config.max_position_embeddings, config.hidden_size)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
elif self.position_embedding_type == "rotary":
self.position_embeddings = None
self.position_ids = None
else:
raise ValueError("Only 'absolute' and 'rotary' " +
"position_embedding_type is supported")
def forward( def forward(
self, self,
...@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module): ...@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
# Input embeddings. # Input embeddings.
inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, token_type_ids = torch.zeros(input_shape,
dtype=torch.long, dtype=torch.long,
...@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module): ...@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
token_type_embeddings = self.token_type_embeddings(token_type_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings) embeddings = self.LayerNorm(embeddings)
return embeddings return embeddings
...@@ -98,7 +106,10 @@ class BertPooler(nn.Module): ...@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
@support_torch_compile @support_torch_compile
class BertEncoder(nn.Module): class BertEncoder(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
vllm_config: VllmConfig,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -107,16 +118,18 @@ class BertEncoder(nn.Module): ...@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
BertLayer(config=config, BertLayer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}") prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
def forward( def forward(
self, self,
positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
for layer in self.layer: for layer in self.layer:
hidden_states = layer(hidden_states) hidden_states = layer(positions, hidden_states)
return hidden_states return hidden_states
...@@ -126,6 +139,7 @@ class BertLayer(nn.Module): ...@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
config: BertConfig, config: BertConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
...@@ -135,6 +149,7 @@ class BertLayer(nn.Module): ...@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate( self.intermediate = BertIntermediate(
...@@ -150,8 +165,8 @@ class BertLayer(nn.Module): ...@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
def forward(self, hidden_states: torch.Tensor): def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
attn_output = self.attention(hidden_states) attn_output = self.attention(positions, hidden_states)
intermediate_output = self.intermediate(attn_output) intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output) output = self.output(intermediate_output, attn_output)
return output return output
...@@ -166,6 +181,7 @@ class BertAttention(nn.Module): ...@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
layer_norm_eps: float, layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -174,6 +190,7 @@ class BertAttention(nn.Module): ...@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size, self.output = BertSelfOutput(hidden_size=hidden_size,
...@@ -183,9 +200,10 @@ class BertAttention(nn.Module): ...@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
def forward( def forward(
self, self,
positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
self_output = self.self(hidden_states) self_output = self.self(positions, hidden_states)
return self.output(self_output, hidden_states) return self.output(self_output, hidden_states)
...@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module): ...@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int, num_attention_heads: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module): ...@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj") prefix=f"{prefix}.qkv_proj")
if rotary_kwargs:
self.rotary_emb = get_rope(**rotary_kwargs)
else:
self.rotary_emb = None
self.attn = Attention(num_heads=self.num_heads, self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scaling, scale=self.scaling,
...@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module): ...@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
def forward( def forward(
self, self,
positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb:
q, k = self.rotary_emb(positions, q, k)
output = self.attn(q, k, v) output = self.attn(q, k, v)
return output return output
...@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
embedding_class: type = BertEmbedding, embedding_class: type = BertEmbedding,
rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False): add_pooling_layer: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config) self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config, self.encoder = BertEncoder(vllm_config=vllm_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None self.pooler = BertPooler(config) if add_pooling_layer else None
...@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens=attn_metadata.seq_lens_tensor, seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(hidden_states) return self.encoder(position_ids, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
...@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.config = vllm_config.model_config.hf_config
self.model = self._build_model(vllm_config=vllm_config, self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config) self._pooler = self._build_pooler(pooler_config)
......
...@@ -15,12 +15,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -15,12 +15,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel from .blip import BlipVisionModel
...@@ -406,13 +407,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo): ...@@ -406,13 +407,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int: def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
return hf_config.num_query_tokens return hf_config.num_query_tokens
...@@ -420,29 +414,27 @@ class Blip2ProcessingInfo(BaseProcessingInfo): ...@@ -420,29 +414,27 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
max_image_size = vision_config.image_size max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=max_image_size, self._get_dummy_images(width=max_image_size,
height=max_image_size, height=max_image_size,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
...@@ -627,6 +619,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -627,6 +619,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_projection(query_output) return self.language_projection(query_output)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -30,12 +30,13 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -30,12 +30,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
...@@ -64,13 +65,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo): ...@@ -64,13 +65,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int: def get_num_image_tokens(self) -> int:
processor = self.get_hf_processor() processor = self.get_hf_processor()
return processor.image_seq_length return processor.image_seq_length
...@@ -79,28 +73,31 @@ class ChameleonProcessingInfo(BaseProcessingInfo): ...@@ -79,28 +73,31 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
class ChameleonDummyInputsBuilder( class ChameleonDummyInputsBuilder(
BaseDummyInputsBuilder[ChameleonProcessingInfo]): BaseDummyInputsBuilder[ChameleonProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
config = self.info.get_hf_config() config = self.info.get_hf_config()
width = height = config.vq_config.resolution width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=width, self._get_dummy_images(width=width,
height=height, height=height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class ChameleonMultiModalProcessor( class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]): BaseMultiModalProcessor[ChameleonProcessingInfo]):
...@@ -162,9 +159,9 @@ class ChameleonMultiModalProcessor( ...@@ -162,9 +159,9 @@ class ChameleonMultiModalProcessor(
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id], target=[image_token_id],
replacement=PromptUpdateDetails( replacement=PromptUpdateDetails.select_token_id(
full=([image_start_id] + image_tokens + [image_end_id]), [image_start_id] + image_tokens + [image_end_id],
features=image_tokens, embed_token_id=image_token_id,
), ),
) )
] ]
...@@ -988,6 +985,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -988,6 +985,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): ...@@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
) -> int: ) -> int:
return self.get_patch_grid_length()**2 + 1 return self.get_patch_grid_length()**2 + 1
def get_max_image_tokens(self) -> int:
return self.get_patch_grid_length()**2 + 1
def get_image_size(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size return self.vision_config.image_size
......
...@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (extract_layer_index, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module): ...@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class DeepseekForCausalLM(nn.Module, SupportsPP): class DeepseekForCausalLM(nn.Module, SupportsPP):
...@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ loader = AutoWeightsLoader(self)
# (param_name, shard_name, shard_id) return loader.load_weights(weights)
("qkv_proj", "q_proj", "q"), \ No newline at end of file
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -163,14 +163,16 @@ class DeepseekV2MoE(nn.Module): ...@@ -163,14 +163,16 @@ class DeepseekV2MoE(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor router_logits=router_logits) * self.routed_scaling_factor
else: else:
# This is a special case to avoid FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
else: else:
# This is a special case to avoid FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
...@@ -502,6 +504,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -502,6 +504,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# DecoderLayers are created with `make_layers` which passes the prefix # DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index. # with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1]) layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
if model_config.use_mla: if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention attn_cls = DeepseekV2MLAAttention
else: else:
...@@ -564,19 +567,30 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -564,19 +567,30 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
) )
# Fully Connected if hidden_states.dtype == torch.float16:
if isinstance(self.mlp, DeepseekV2MoE) and \ # Fix FP16 overflow
hidden_states.dtype == torch.float16: # We scale both hidden_states and residual before
# This is a special case to avoid FP16 overflow # rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, DeepseekV2MLP) and \
hidden_states.dtype == torch.float16: if isinstance(self.mlp,
# This is a special case to avoid FP16 overflow DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
residual *= 1. / self.routed_scaling_factor
return hidden_states, residual return hidden_states, residual
......
...@@ -19,14 +19,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -19,14 +19,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig, MlpProjectorConfig,
...@@ -168,47 +168,34 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): ...@@ -168,47 +168,34 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
image_width=x[1], image_height=x[0])) image_width=x[1], image_height=x[0]))
return ImageSize(width=width, height=height) return ImageSize(width=width, height=height)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
num_images = mm_counts.get("image", 0)
max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height,
image_width=max_image_size.width,
cropping=num_images <= 2)
return {"image": max_image_tokens}
class DeepseekVL2DummyInputsBuilder( class DeepseekVL2DummyInputsBuilder(
BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
max_image_size = self.info.get_image_size_with_most_features() max_image_size = self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=max_image_size.width, self._get_dummy_images(width=max_image_size.width,
height=max_image_size.height, height=max_image_size.height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class DeepseekVL2MultiModalProcessor( class DeepseekVL2MultiModalProcessor(
BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
...@@ -604,6 +591,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -604,6 +591,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding( return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature, PretrainedConfig from transformers import BartTokenizer, BatchFeature, PretrainedConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -21,13 +21,14 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, ...@@ -21,13 +21,14 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartScaledWordEmbedding) BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptIndexTargets, PromptInsertion, PromptIndexTargets, PromptInsertion,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
...@@ -764,42 +765,33 @@ class Florence2ProcessingInfo(BaseProcessingInfo): ...@@ -764,42 +765,33 @@ class Florence2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_max_image_tokens(self) -> int: def get_num_image_tokens(self) -> int:
processor_config = self.ctx.get_hf_image_processor_config() processor_config = self.ctx.get_hf_image_processor_config()
return processor_config["image_seq_length"] return processor_config["image_seq_length"]
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
class Florence2DummyInputsBuilder( class Florence2DummyInputsBuilder(
BaseDummyInputsBuilder[Florence2ProcessingInfo]): BaseDummyInputsBuilder[Florence2ProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width = target_height = self.info.get_hf_config().projection_dim target_width = target_height = self.info.get_hf_config().projection_dim
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class Florence2MultiModalProcessor( class Florence2MultiModalProcessor(
EncDecMultiModalProcessor[Florence2ProcessingInfo]): EncDecMultiModalProcessor[Florence2ProcessingInfo]):
...@@ -826,6 +818,18 @@ class Florence2MultiModalProcessor( ...@@ -826,6 +818,18 @@ class Florence2MultiModalProcessor(
) -> Union[str, list[int]]: ) -> Union[str, list[int]]:
return [self.info.get_hf_config().eos_token_id] return [self.info.get_hf_config().eos_token_id]
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
hf_processor = self.info.get_hf_processor()
tokenizer: BartTokenizer = hf_processor.tokenizer
prompt_text = tokenizer.decode(prompt_tokens)
# convert task tokens to prompt
prompt_text = hf_processor._construct_prompts([prompt_text])[0]
prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
return prompt_tokens
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
...@@ -859,7 +863,7 @@ class Florence2MultiModalProcessor( ...@@ -859,7 +863,7 @@ class Florence2MultiModalProcessor(
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id pad_token_id = hf_config.pad_token_id
num_image_tokens = self.info.get_max_image_tokens() num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens image_tokens = [pad_token_id] * num_image_tokens
return [ return [
...@@ -1038,6 +1042,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1038,6 +1042,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = image_input["data"] pixel_values = image_input["data"]
return self._encode_image(pixel_values) return self._encode_image(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, Set, Tuple, TypedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -31,19 +31,19 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -31,19 +31,19 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -66,14 +66,6 @@ class FuyuImagePatchInputs(TypedDict): ...@@ -66,14 +66,6 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class FuyuProcessingInfo(BaseProcessingInfo): class FuyuProcessingInfo(BaseProcessingInfo):
...@@ -89,21 +81,6 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -89,21 +81,6 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def get_image_feature_grid_size( def get_image_feature_grid_size(
self, self,
*, *,
...@@ -128,6 +105,19 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -128,6 +105,19 @@ class FuyuProcessingInfo(BaseProcessingInfo):
nrows = math.ceil(image_height / patch_height) nrows = math.ceil(image_height / patch_height)
return ncols, nrows return ncols, nrows
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
ncols, nrows = self.get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
)
return ncols * nrows
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor() image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"], return ImageSize(width=image_processor.size["width"],
...@@ -136,27 +126,25 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -136,27 +126,25 @@ class FuyuProcessingInfo(BaseProcessingInfo):
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...@@ -192,19 +180,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -192,19 +180,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0] processed_outputs["image_patches"] = image_patches[0]
# get patch grid size for each image
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)
mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
...@@ -224,8 +199,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -224,8 +199,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"), return dict(image_patches=MultiModalFieldConfig.batched("image"))
embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -252,9 +226,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -252,9 +226,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptUpdateDetails( return PromptUpdateDetails.select_token_id(
full=image_tokens + [bos_token_id], image_tokens + [bos_token_id],
features=image_tokens, embed_token_id=_IMAGE_TOKEN_ID,
) )
return [ return [
...@@ -329,20 +303,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -329,20 +303,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image patches. " raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}") f"Got type: {type(image_patches)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches) image_patches_flat = flatten_bn(image_patches)
embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs( return FuyuImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=self._validate_pixel_values( flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)), flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat], patches_per_image=[x.size(0) for x in image_patches_flat],
embed_is_patch=embed_is_patch,
) )
return None return None
...@@ -358,18 +325,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -358,18 +325,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings_flat.split(patches_per_image, dim=0) return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -379,8 +344,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -379,8 +344,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, input_ids,
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID) inputs_embeds,
multimodal_embeddings,
_IMAGE_TOKEN_ID,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -319,6 +319,46 @@ class GemmaModel(nn.Module): ...@@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Checkpoint entries whose name contains one of the separate projection
    names (``q_proj``/``k_proj``/``v_proj``, ``gate_proj``/``up_proj``) are
    renamed to the fused parameter and loaded as a shard of it; every other
    entry is loaded under its own name.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names (after renaming) that were loaded.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    shard_table = (
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    )
    own_params = dict(self.named_parameters())
    loaded: Set[str] = set()

    for name, tensor in weights:
        for fused_name, shard_name, shard_id in shard_table:
            if shard_name not in name:
                continue
            name = name.replace(shard_name, fused_name)
            # Skip extra biases (GPTQ checkpoints) and parameters that
            # is_pp_missing_parameter reports as absent from this module.
            # This `continue` only advances the shard scan; the renamed
            # entry is re-checked (and skipped) in the `else` branch below.
            if ((name.endswith(".bias") and name not in own_params)
                    or is_pp_missing_parameter(name, self)):
                continue
            target = own_params[name]
            target.weight_loader(target, tensor, shard_id)
            break
        else:
            # No shard was loaded for this entry: treat it as a plain
            # parameter, applying the same skip rules as above.
            if ((name.endswith(".bias") and name not in own_params)
                    or is_pp_missing_parameter(name, self)):
                continue
            target = own_params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
        loaded.add(name)
    return loaded
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ loader = AutoWeightsLoader(
# (param_name, shard_name, shard_id) self,
("qkv_proj", "q_proj", "q"), skip_prefixes=(["lm_head."]
("qkv_proj", "k_proj", "k"), if self.config.tie_word_embeddings else None),
("qkv_proj", "v_proj", "v"), )
("gate_up_proj", "gate_proj", 0), return loader.load_weights(weights)
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -15,8 +15,9 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm ...@@ -15,8 +15,9 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
# yapf: disable # yapf: disable
...@@ -25,10 +26,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -25,10 +26,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch, PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails, PromptUpdate, PromptUpdateDetails,
encode_tokens, find_mm_placeholders, find_mm_placeholders,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...@@ -36,7 +37,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, ...@@ -36,7 +37,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,14 +54,6 @@ class Gemma3ImagePixelInputs(TypedDict): ...@@ -54,14 +54,6 @@ class Gemma3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
Gemma3ImageInputs = Gemma3ImagePixelInputs Gemma3ImageInputs = Gemma3ImagePixelInputs
...@@ -77,13 +69,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -77,13 +69,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _resolve_image_kwargs( def _resolve_image_kwargs(
self, self,
processor: Gemma3Processor, processor: Gemma3Processor,
...@@ -183,7 +168,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -183,7 +168,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()
image_token = processor.boi_token boi_token = processor.boi_token
num_crops = self.get_num_crops( num_crops = self.get_num_crops(
image_width=image_width, image_width=image_width,
...@@ -192,19 +177,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -192,19 +177,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
) )
if num_crops == 0: if num_crops == 0:
image_text = image_token image_text = boi_token
else: else:
crops_image_tokens = " ".join(image_token crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
for _ in range(num_crops))
image_text = ( image_text = (
f"Here is the original image {image_token} and here are some " f"Here is the original image {boi_token} and here are some "
f"crops to help you see better {crops_image_tokens}") f"crops to help you see better {crops_image_tokens}")
repl_full = image_text.replace(image_token, repl_full = image_text.replace(boi_token,
processor.full_image_sequence) processor.full_image_sequence)
repl_features = repl_full.strip("\n")
return PromptUpdateDetails(full=repl_full, features=repl_features) tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
def get_num_image_tokens( def get_num_image_tokens(
self, self,
...@@ -213,19 +200,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -213,19 +200,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
processor: Optional[Gemma3Processor], processor: Optional[Gemma3Processor],
) -> int: ) -> int:
tokenizer = self.get_tokenizer() if processor is None:
image_repl = self.get_image_repl( processor = self.get_hf_processor()
num_crops = self.get_num_crops(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
) )
image_seq_len = processor.image_seq_length
image_repl_tokens = encode_tokens( return (num_crops + 1) * image_seq_len
tokenizer,
image_repl.features,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -237,43 +222,34 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -237,43 +222,34 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
# Result in the max possible feature size (h:w = max_num_crops:1) # Result in the max possible feature size (h:w = max_num_crops:1)
return ImageSize(height=50 * max_num_crops, width=50) return ImageSize(height=50 * max_num_crops, width=50)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens( class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): processor = self.info.get_hf_processor()
image_token = processor.boi_token
return image_token * num_images
def get_dummy_processor_inputs( def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
processor = self.info.get_hf_processor()
image_token = processor.boi_token
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...@@ -301,28 +277,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -301,28 +277,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
] ]
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor).features
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_crops = [ num_crops = [
self.info.get_num_crops(image_width=size.width, self.info.get_num_crops(image_width=size.width,
image_height=size.height, image_height=size.height,
...@@ -344,7 +298,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -344,7 +298,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes( pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1), "image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -454,6 +407,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -454,6 +407,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
item_idx=p.item_idx, item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx], start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens, tokens=p.tokens,
is_embed=p.is_embed,
) for p in placeholders ) for p in placeholders
] ]
for modality, placeholders in repls.items() for modality, placeholders in repls.items()
...@@ -572,7 +526,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -572,7 +526,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]: self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None) num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds." assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None: if pixel_values is None:
...@@ -586,19 +539,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -586,19 +539,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}") f"Got type: {type(num_crops)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True) num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs( return Gemma3ImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1, num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
) )
def _image_pixels_to_features( def _image_pixels_to_features(
...@@ -629,18 +576,16 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -629,18 +576,16 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -652,7 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -652,7 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Zhipu AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import Glm4Config
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .llama import LlamaMLP as Glm4MLP
from .llama import LlamaModel
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
class Glm4Attention(nn.Module):
    """Self-attention block for GLM-4 with partial rotary position embeddings.

    The block owns a fused QKV projection, an output projection, the rotary
    embedding and the attention backend. Query heads are divided across the
    tensor-parallel group; KV heads are divided when there are at least as
    many KV heads as ranks and replicated otherwise.
    """

    def __init__(self,
                 config: Glm4Config,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 head_dim: Optional[int] = None,
                 qkv_bias: bool = False,
                 rope_theta: float = 10000,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 rope_scaling: Optional[Tuple] = None,
                 prefix: str = "",
                 attn_type: str = AttentionType.DECODER) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF config; only ``partial_rotary_factor`` is read here
                (default 0.5 when absent).
            hidden_size: Width of the input/output hidden states.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP sharding).
            max_position: Maximum position passed to the rotary embedding.
            head_dim: Per-head width; defaults to ``hidden_size // num_heads``.
            qkv_bias: Whether the fused QKV projection has a bias.
            rope_theta: Rotary base frequency.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            rope_scaling: Forwarded to ``get_rope``.
            prefix: Parameter-name prefix of this module.
            attn_type: Attention type forwarded to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # At least as many KV heads as TP ranks, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        # Number of channels per head that actually receive rotary embedding.
        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # get_rope itself scales `rotary_dim` by `partial_rotary_factor`, so
        # the full head width is passed here; passing the already-scaled
        # `self.rotary_dim` would apply the factor twice.
        # GLM rotates interleaved (even/odd) channel pairs rather than the
        # two halves of the head, hence `is_neox_style=False`.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            partial_rotary_factor=partial_rotary_factor,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=attn_type)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations; the last dimension is split
                into ``q_size``/``kv_size``/``kv_size`` after the QKV
                projection.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class Glm4DecoderLayer(nn.Module):
    """One GLM-4 transformer layer: attention and MLP, each wrapped in norms.

    Each sub-block is normalised both before (``input_layernorm`` /
    ``post_attention_layernorm``) and after (``post_self_attn_layernorm`` /
    ``post_mlp_layernorm``) the computation.
    """

    def __init__(
        self,
        config: Glm4Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and the four RMSNorm layers.

        Args:
            config: HF GLM-4 config supplying sizes, head counts, rope
                settings, activation and norm epsilon.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = Glm4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=AttentionType.DECODER,
        )
        self.mlp = Glm4MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.post_mlp_layernorm = RMSNorm(config.hidden_size,
                                          eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        The returned ``hidden_states`` is the MLP branch output *without* the
        residual added; the consumer (the next layer's ``input_layernorm`` or
        the model's final norm) is expected to add ``residual`` through the
        two-argument norm call.

        Args:
            positions: Token positions forwarded to the attention block.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream, or ``None`` for the first
                layer.

        Returns:
            The pair ``(hidden_states, residual)`` described above.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument norm: folds hidden_states into the residual stream
            # and returns (normalised activations, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)

        # Fully Connected
        # The two-argument norm already adds the attention output to the
        # residual, so no manual add is done here, and its tuple result must
        # be unpacked before the MLP is called.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)

        return hidden_states, residual
# Maps a layer-type name to the decoder layer class that implements it.
# NOTE(review): not referenced anywhere else in the visible part of this
# file — confirm whether external code relies on it.
ALL_DECODER_LAYER_TYPES = {
    "attention": Glm4DecoderLayer,
}
@support_torch_compile(
    dynamic_arg_dims=dict(
        input_ids=0,
        positions=-1,
        intermediate_tensors=0,
        inputs_embeds=0,
    ))
class Glm4Model(LlamaModel):
    """``LlamaModel`` whose decoder layers are ``Glm4DecoderLayer``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Delegate construction to ``LlamaModel`` with the GLM-4 layer type.

        Args:
            vllm_config: Engine-wide configuration passed through unchanged.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=Glm4DecoderLayer,
        )
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GLM-4-0414 causal LM: decoder stack, LM head, logits and sampling."""

    # Fused projection name -> checkpoint projection names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack and the output head.

        Args:
            vllm_config: Engine-wide configuration; the HF config, the
                quantization config and the LoRA config are read from it.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config
        lora = vllm_config.lora_config

        self.config = hf_config
        self.lora_config = lora
        self.quant_config = quantization

        self.model = Glm4Model(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Only the last pipeline rank gets a real head; with tied embeddings
        # the head is the input embedding module itself.
        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quantization,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the inner model's input embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return whatever the inner model returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        With tied word embeddings, checkpoint entries under ``lm_head.`` are
        skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            a set of names).
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights)
...@@ -12,7 +12,7 @@ from torch import nn ...@@ -12,7 +12,7 @@ from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from torchvision import transforms from torchvision import transforms
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from transformers import PreTrainedTokenizer, TensorType from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
...@@ -28,13 +28,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -28,13 +28,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature, BaseProcessingInfo, PromptReplacement,
MultiModalFieldConfig, PromptUpdate)
PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
...@@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo): ...@@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """Map the ``"image"`` modality to its token count.

    ``seq_len`` and ``mm_counts`` are accepted but not read; the single
    entry comes from ``self.get_num_image_feature_tokens()``.
    """
    return {"image": self.get_num_image_feature_tokens()}
def get_num_image_tokens(self) -> int: def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -454,31 +447,31 @@ class GLM4VProcessingInfo(BaseProcessingInfo): ...@@ -454,31 +447,31 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
return base_text * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
target_width = target_height = vision_config["image_size"] target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
return ProcessorInputs(
prompt_text=base_text * num_images,
mm_data=mm_data,
)
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
...@@ -596,6 +589,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -596,6 +589,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values) return self.transformer.vision(pixel_values)
def get_language_model(self) -> torch.nn.Module:
    """Return ``self.transformer`` as this model's language-model module."""
    return self.transformer
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
maybe_prefix) make_layers, maybe_prefix)
class GraniteMLP(nn.Module): class GraniteMLP(nn.Module):
...@@ -260,6 +260,7 @@ class GraniteModel(nn.Module): ...@@ -260,6 +260,7 @@ class GraniteModel(nn.Module):
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab self.vocab_size = config.vocab_size + lora_vocab
...@@ -321,6 +322,65 @@ class GraniteModel(nn.Module): ...@@ -321,6 +322,65 @@ class GraniteModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Each ``(name, tensor)`` pair takes one of three paths:

    1. If ``self.quant_config`` yields a cache-scale name for it, the
       tensor (or its first element when it is not 0-dim) is loaded into
       that scale parameter.
    2. If the name contains one of the shard suffixes in
       ``stacked_params_mapping``, it is renamed to the fused parameter
       and loaded with the matching shard id.
    3. Otherwise it is loaded directly, after an optional kv-scale rename.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The set of parameter names recorded as loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()
    for name, loaded_weight in weights:
        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Reached only when no stacked mapping loaded the tensor
            # (the inner loop finished without ``break``).
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ skip_prefixes = [
# (param_name, shard_name, shard_id) "rotary_emb.inv_freq",
(".qkv_proj", ".q_proj", "q"), # Models trained using ColossalAI may include these tensors in
(".qkv_proj", ".k_proj", "k"), # the checkpoint. Skip them.
(".qkv_proj", ".v_proj", "v"), "rotary_emb.cos_cached",
(".gate_up_proj", ".gate_proj", 0), "rotary_emb.sin_cached",
(".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) # With tie_word_embeddings, we can skip lm_head.weight
loaded_params: Set[str] = set() # The weight might appear unnecessarily in the files if the model is
for name, loaded_weight in weights: # processed with quantization, LoRA, fine-tuning, etc.
if "rotary_emb.inv_freq" in name: if self.config.tie_word_embeddings:
continue skip_prefixes.append("lm_head.weight")
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name): loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
# Models trained using ColossalAI may include these tensors in return loader.load_weights(weights)
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors ...@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from . import mixtral from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers, maybe_prefix from .utils import AutoWeightsLoader, make_layers, maybe_prefix
class GraniteMoeMoE(nn.Module): class GraniteMoeMoE(nn.Module):
...@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module): ...@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config # Required by MixtralModel
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab self.vocab_size = config.vocab_size + lora_vocab
...@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module): ...@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Rename GraniteMoe checkpoint tensors and load them Mixtral-style.

    The incoming names are rewritten before delegating to
    ``mixtral.MixtralModel.load_weights``:

    * ``...block_sparse_moe.input_linear.weight`` is split along its first
      dimension (one slice per expert); each slice is chunked in two along
      dim 0 and stored as ``experts.{e}.w1.weight`` / ``experts.{e}.w3.weight``.
    * ``...block_sparse_moe.output_linear.weight`` is split along its first
      dimension into ``experts.{e}.w2.weight``.
    * ``...block_sparse_moe.router.layer.weight`` becomes
      ``...block_sparse_moe.gate.weight``.
    * Every other entry is passed through unchanged.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The value returned by ``mixtral.MixtralModel.load_weights``.
    """
    new_weights = {}
    for n, p in weights:
        if n.endswith('.block_sparse_moe.input_linear.weight'):
            for e in range(p.size(0)):
                w1_name = n.replace(
                    '.block_sparse_moe.input_linear.weight',
                    f".block_sparse_moe.experts.{e}.w1.weight")
                w3_name = n.replace(
                    '.block_sparse_moe.input_linear.weight',
                    f".block_sparse_moe.experts.{e}.w3.weight")
                w1_param, w3_param = p[e].chunk(2, dim=0)
                # A rewritten name must never collide with an earlier one.
                assert w1_name not in new_weights
                assert w3_name not in new_weights
                new_weights[w1_name] = w1_param
                new_weights[w3_name] = w3_param
        elif n.endswith('.block_sparse_moe.output_linear.weight'):
            for e in range(p.size(0)):
                w2_name = n.replace(
                    '.block_sparse_moe.output_linear.weight',
                    f".block_sparse_moe.experts.{e}.w2.weight")
                w2_param = p[e]
                assert w2_name not in new_weights
                new_weights[w2_name] = w2_param
        elif n.endswith('.block_sparse_moe.router.layer.weight'):
            gate_name = n.replace('.block_sparse_moe.router.layer.weight',
                                  ".block_sparse_moe.gate.weight")
            assert gate_name not in new_weights
            new_weights[gate_name] = p
        else:
            new_weights[n] = p
    return mixtral.MixtralModel.load_weights(self, new_weights.items())
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config # Required by MixtralForCausalLM
self.model = GraniteMoeModel(vllm_config=vllm_config, self.model = GraniteMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
...@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
new_weights = {} loader = AutoWeightsLoader(
for n, p in weights: self,
if n.endswith('.block_sparse_moe.input_linear.weight'): skip_prefixes=(["lm_head."]
for e in range(p.size(0)): if self.config.tie_word_embeddings else None),
w1_name = n.replace( )
'.block_sparse_moe.input_linear.weight', return loader.load_weights(weights)
f".block_sparse_moe.experts.{e}.w1.weight")
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w3.weight")
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
f".block_sparse_moe.experts.{e}.w2.weight")
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
pass
else:
new_weights[n] = p
return mixtral.MixtralForCausalLM.load_weights(self,
new_weights.items())
...@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors ...@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors
from . import mixtral from . import mixtral
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers, maybe_prefix from .utils import AutoWeightsLoader, make_layers, maybe_prefix
class GraniteMoeSharedMLP(nn.Module): class GraniteMoeSharedMLP(nn.Module):
...@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config # Required by MixtralModel
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Rewrite MoE checkpoint names, then load via the Mixtral loader.

    Name handling, by suffix:

    * ``.block_sparse_moe.input_linear.weight``: for each index ``e`` along
      the tensor's first dimension, ``p[e]`` is chunked in two along dim 0
      and stored under ``experts.{e}.w1.weight`` and ``experts.{e}.w3.weight``.
    * ``.block_sparse_moe.output_linear.weight``: ``p[e]`` is stored under
      ``experts.{e}.w2.weight``.
    * ``.block_sparse_moe.router.layer.weight``: renamed to
      ``.block_sparse_moe.gate.weight``.
    * Anything else keeps its name.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The value returned by ``mixtral.MixtralModel.load_weights``.
    """
    new_weights = {}
    for n, p in weights:
        if n.endswith('.block_sparse_moe.input_linear.weight'):
            for e in range(p.size(0)):
                w1_name = n.replace(
                    '.block_sparse_moe.input_linear.weight',
                    f".block_sparse_moe.experts.{e}.w1.weight")
                w3_name = n.replace(
                    '.block_sparse_moe.input_linear.weight',
                    f".block_sparse_moe.experts.{e}.w3.weight")
                w1_param, w3_param = p[e].chunk(2, dim=0)
                # Guard against two checkpoint entries mapping to one name.
                assert w1_name not in new_weights
                assert w3_name not in new_weights
                new_weights[w1_name] = w1_param
                new_weights[w3_name] = w3_param
        elif n.endswith('.block_sparse_moe.output_linear.weight'):
            for e in range(p.size(0)):
                w2_name = n.replace(
                    '.block_sparse_moe.output_linear.weight',
                    f".block_sparse_moe.experts.{e}.w2.weight")
                w2_param = p[e]
                assert w2_name not in new_weights
                new_weights[w2_name] = w2_param
        elif n.endswith('.block_sparse_moe.router.layer.weight'):
            gate_name = n.replace('.block_sparse_moe.router.layer.weight',
                                  ".block_sparse_moe.gate.weight")
            assert gate_name not in new_weights
            new_weights[gate_name] = p
        else:
            new_weights[n] = p
    return mixtral.MixtralModel.load_weights(self, new_weights.items())
class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config
self.model = GraniteMoeSharedModel(vllm_config=vllm_config, self.model = GraniteMoeSharedModel(vllm_config=vllm_config,
prefix=maybe_prefix( prefix=maybe_prefix(
...@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
new_weights = {} loader = AutoWeightsLoader(
for n, p in weights: self,
if n.endswith('.block_sparse_moe.input_linear.weight'): skip_prefixes=(["lm_head."]
for e in range(p.size(0)): if self.config.tie_word_embeddings else None),
w1_name = n.replace( )
'.block_sparse_moe.input_linear.weight', return loader.load_weights(weights)
f".block_sparse_moe.experts.{e}.w1.weight")
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w3.weight")
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
f".block_sparse_moe.experts.{e}.w2.weight")
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
pass
else:
new_weights[n] = p
return mixtral.MixtralForCausalLM.load_weights(self,
new_weights.items())
...@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -302,6 +302,8 @@ class Grok1Model(nn.Module): ...@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -370,94 +372,6 @@ class Grok1Model(nn.Module): ...@@ -370,94 +372,6 @@ class Grok1Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone, LM head, logits processor and sampler.

    Args:
        vllm_config: Source of the HF model config, quantization config
            and LoRA config read below.
        prefix: Name prefix combined with ``"model"`` / ``"lm_head"`` via
            ``maybe_prefix`` for the sub-modules.
    """
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config

    self.config = config
    self.lora_config = lora_config
    self.quant_config = quant_config

    self.model = Grok1Model(vllm_config=vllm_config,
                            prefix=maybe_prefix(prefix, "model"))

    # Vocabulary size grows by the LoRA extra vocab when LoRA is configured.
    self.unpadded_vocab_size = config.vocab_size
    if lora_config:
        self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

    self.lm_head = ParallelLMHead(
        self.unpadded_vocab_size,
        config.hidden_size,
        org_num_embeddings=config.vocab_size,
        padding_size=DEFAULT_VOCAB_PADDING_SIZE
        # We need bigger padding if using lora for kernel compatibility
        if not lora_config else lora_config.lora_vocab_padding_size,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "lm_head"),
    )

    # With tied embeddings the LM head shares the input embedding weight.
    if self.config.tie_word_embeddings:
        self.lm_head.weight = self.model.embed_tokens.weight

    self.output_multiplier_scale = getattr(
        config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE)
    self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                            config.vocab_size,
                                            self.output_multiplier_scale)
    self.sampler = get_sampler()

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate to ``self.model.get_input_embeddings`` for ``input_ids``."""
    return self.model.get_input_embeddings(input_ids)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Call ``self.model`` with all arguments, in order, and return its result.

    No processing happens here beyond forwarding the six arguments
    positionally to the backbone.
    """
    hidden_states = self.model(input_ids, positions, kv_caches,
                               attn_metadata, intermediate_tensors,
                               inputs_embeds)
    return hidden_states
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Compute logits by calling ``self.logits_processor``.

    The processor receives ``self.lm_head``, ``hidden_states`` and
    ``sampling_metadata`` in that order; its result is returned as-is.
    """
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)
def sample(
    self,
    logits: Optional[torch.Tensor],
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Return the result of ``self.sampler(logits, sampling_metadata)``."""
    return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales # Loading kv cache quantization scales
...@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if "norm.scale" in name: if "norm.scale" in name:
name = name.replace("scale", "weight") name = name.replace("scale", "weight")
# Skip lm_head when tie_word_embeddings is True
if "lm_head" in name and self.config.tie_word_embeddings:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper: a ``Grok1Model`` backbone plus a parallel LM head.

    Weight loading is delegated to ``AutoWeightsLoader``; logits come from
    ``LogitsProcessor`` and sampling from the sampler built by ``get_sampler``.
    """

    fall_back_to_pt_during_load = False

    # Checkpoint projections that are fused into a single packed module.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Source of the HF model config, quantization config
                and LoRA config read below.
            prefix: Name prefix combined with ``"model"`` / ``"lm_head"``
                via ``maybe_prefix`` for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = Grok1Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        # Vocabulary size grows by the LoRA extra vocab when LoRA is set.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        # With tied embeddings the LM head shares the input embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.output_multiplier_scale = getattr(
            config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                self.output_multiplier_scale)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to ``self.model.get_input_embeddings``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward all arguments positionally to ``self.model``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return ``self.logits_processor(self.lm_head, hidden_states, ...)``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return ``self.sampler(logits, sampling_metadata)``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        # NOTE(review): "rotary_emb.inv_freq" is not a top-level name;
        # confirm that AutoWeightsLoader's ``skip_prefixes`` also matches
        # names nested under a layer path, otherwise this entry is a no-op.
        skip_prefixes = ["rotary_emb.inv_freq"]
        # Skip lm_head when tie_word_embeddings is True
        if self.config.tie_word_embeddings:
            skip_prefixes.append("lm_head")

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)
...@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num( def resolve_min_max_num(
self, self,
...@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
**kwargs, **kwargs,
) )
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """Return the ``"image"`` token count, chosen by the image count.

    With at most one image in ``mm_counts`` (a missing key counts as 0) the
    result of ``self.get_max_image_tokens(use_msac=None)`` is used; with
    more, ``self.get_max_image_tokens(use_msac=False)``. ``seq_len`` is
    not read.
    """
    # Evaluated unconditionally, as in the original statement order.
    max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
    if mm_counts.get("image", 0) <= 1:
        max_tokens_per_image = max_tokens_one_image
    else:
        max_tokens_per_image = self.get_max_image_tokens(use_msac=False)

    return {"image": max_tokens_per_image}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
use_msac=use_msac, use_msac=use_msac,
) )
def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
    """Return the token count for the size from ``get_image_size_with_most_features``.

    Args:
        use_msac: Forwarded unchanged to ``self.get_num_image_tokens``.

    Returns:
        ``self.get_num_image_tokens`` evaluated at that width and height,
        with ``processor=None``.
    """
    target_width, target_height = self.get_image_size_with_most_features()

    return self.get_num_image_tokens(
        image_width=target_width,
        image_height=target_height,
        processor=None,
        use_msac=use_msac,
    )
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
): ):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment