Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
...@@ -38,7 +38,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -38,7 +38,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (flatten_bn, is_pp_missing_parameter, from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -927,7 +928,11 @@ class ChameleonModel(nn.Module): ...@@ -927,7 +928,11 @@ class ChameleonModel(nn.Module):
info=ChameleonProcessingInfo, info=ChameleonProcessingInfo,
dummy_inputs=ChameleonDummyInputsBuilder) dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -12,6 +12,7 @@ import os ...@@ -12,6 +12,7 @@ import os
import re import re
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -30,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -30,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -305,7 +306,12 @@ class GLMTransformer(nn.Module): ...@@ -305,7 +306,12 @@ class GLMTransformer(nn.Module):
return hidden_states return hidden_states
class ChatGLMModel(nn.Module): @support_torch_compile
class ChatGLMModel(nn.Module, SupportsQuant):
packed_modules_mapping = {
"linear_proj.merged_proj":
["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -458,7 +464,6 @@ class ChatGLMModel(nn.Module): ...@@ -458,7 +464,6 @@ class ChatGLMModel(nn.Module):
class ChatGLMBaseModel(nn.Module): class ChatGLMBaseModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".word_embeddings": ""}, ) orig_to_new_substr={".word_embeddings": ""}, )
...@@ -516,7 +521,8 @@ class ChatGLMBaseModel(nn.Module): ...@@ -516,7 +521,8 @@ class ChatGLMBaseModel(nn.Module):
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP): class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsQuant):
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
from typing import Iterable, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers import CohereConfig from transformers import CohereConfig
...@@ -50,7 +49,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -50,7 +49,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (extract_layer_index, is_pp_missing_parameter, from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -333,7 +332,7 @@ class CohereModel(nn.Module): ...@@ -333,7 +332,7 @@ class CohereModel(nn.Module):
return hidden_states return hidden_states
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Set, Tuple
import torch
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from .utils import is_pp_missing_parameter
class DeciLMForCausalLM(LlamaForCausalLM):
"""
Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overridden with a value
per layer.
Usually, in the HuggingFace implementation, instead of
"config.num_key_value_heads", we use
"config.num_key_value_heads_per_layer[i]" which varies.
Currently, PagedAttention does not work well with variable GQA, so we
normalize the weights upon loading, and use uniform GQA with the max value
instead.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "k_proj" in name or "v_proj" in name:
loaded_weight = self._degroup_weight(loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
head_size = self.config.hidden_size // self.config.num_attention_heads
target_num_kv_heads = self.config.num_key_value_heads
num_kv_heads = loaded_weight.shape[0] // head_size
n_repeats = target_num_kv_heads / num_kv_heads
assert n_repeats == int(n_repeats)
n_repeats = int(n_repeats)
loaded_weight = loaded_weight.view(num_kv_heads, head_size,
hidden_size)
loaded_weight = torch.repeat_interleave(loaded_weight,
repeats=n_repeats,
dim=0)
loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
hidden_size)
return loaded_weight
...@@ -509,7 +509,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -509,7 +509,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
_, hw, n_dim = images_embeds.shape _, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5) h = w = int(hw**0.5)
# 根据self.tile_tag & self.global_view_pos填充image token sequence # fill image token based on self.tile_tag & self.global_view_pos
tile_index = 0 tile_index = 0
vision_embeddings = [] vision_embeddings = []
for jdx in range(images_spatial_crop.size(0)): for jdx in range(images_spatial_crop.size(0)):
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -59,7 +60,15 @@ class EAGLE(nn.Module): ...@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
truncated_vocab_size < vocab_size. To use this technique, one has to find truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.""" needs to have truncated_vocab_size (=k) as an attribute.
4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
module with regards to the use of additional RMS norms. The original
EAGLE architecture 1) skips the pre-attention norm in its first
transformer block, and 2) skips the final output norm, both of which we
found to be suboptimal. We also add the support for separate norms
applying to both the token embedding and hidden states before projection
as in DeepSeek MTP, which we found to improve performance as well.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -81,9 +90,22 @@ class EAGLE(nn.Module): ...@@ -81,9 +90,22 @@ class EAGLE(nn.Module):
# While weights and biases are generally not needed, # While weights and biases are generally not needed,
# they are retained here to support certain unit tests # they are retained here to support certain unit tests
# (e.g., spec_decode/e2e/test_eagle_correctness.py). # (e.g., spec_decode/e2e/test_eagle_correctness.py).
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm( if not hasattr(self.config.model,
weight=self.model.model.layers[0].input_layernorm.weight) "skip_prenorm") or self.config.model.skip_prenorm:
self.model.model.norm = DummyOutputNorm() self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
weight=self.model.model.layers[0].input_layernorm.weight)
if not hasattr(
self.config.model,
"skip_output_norm") or self.config.model.skip_output_norm:
self.model.model.norm = DummyOutputNorm()
self.add_para_norm = False
if hasattr(self.config.model,
"add_para_norm") and self.config.model.add_para_norm:
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.add_para_norm = True
self.orig_vocab_size = config.vocab_size self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size self.truncated_vocab_size = config.truncated_vocab_size
...@@ -128,8 +150,17 @@ class EAGLE(nn.Module): ...@@ -128,8 +150,17 @@ class EAGLE(nn.Module):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.fc( if self.add_para_norm:
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) inputs_embeds = torch.cat([
self.enorm(inputs_embeds),
self.hnorm(previous_hidden_states)
],
dim=-1)
else:
inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
dim=-1)
inputs_embeds = self.fc(inputs_embeds)
inputs_embeds[positions == 0] = 0 # masking inputs at position=0 inputs_embeds[positions == 0] = 0 # masking inputs at position=0
...@@ -190,6 +221,14 @@ class EAGLE(nn.Module): ...@@ -190,6 +221,14 @@ class EAGLE(nn.Module):
else: else:
logger.warning_once("Found bias in the loaded weights but " logger.warning_once("Found bias in the loaded weights but "
"the model config doesn't have bias.") "the model config doesn't have bias.")
elif name.startswith("enorm.weight"):
weight_loader = getattr(self.enorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.enorm.weight, loaded_weight)
elif name.startswith("hnorm.weight"):
weight_loader = getattr(self.hnorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.hnorm.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith( elif name.startswith("model.lm_head.") or name.startswith(
"model.model."): "model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight model_weights[name.split("model.", 1)[-1]] = loaded_weight
......
...@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors ...@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -313,6 +313,7 @@ class ExaoneModel(nn.Module): ...@@ -313,6 +313,7 @@ class ExaoneModel(nn.Module):
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config
lora_vocab = ((lora_config.lora_extra_vocab_size * lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0) (lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab self.vocab_size = config.vocab_size + lora_vocab
...@@ -384,6 +385,72 @@ class ExaoneModel(nn.Module): ...@@ -384,6 +385,72 @@ class ExaoneModel(nn.Module):
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ loader = AutoWeightsLoader(
# (param_name, shard_name, shard_id) self,
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight # With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is # The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc. # processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name: skip_prefixes=(["lm_head."]
continue if self.config.tie_word_embeddings else None),
if (self.quant_config is not None and )
(scale_name := self.quant_config.get_cache_scale(name))): return loader.load_weights(weights)
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors ...@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -382,6 +382,17 @@ class FalconModel(nn.Module): ...@@ -382,6 +382,17 @@ class FalconModel(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids) return self.word_embeddings(input_ids)
...@@ -407,82 +418,6 @@ class FalconModel(nn.Module): ...@@ -407,82 +418,6 @@ class FalconModel(nn.Module):
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class FalconForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
self.tie_word_embeddings = (config.tie_word_embeddings
if config.tie_word_embeddings is not None
else True)
if self.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
...@@ -496,9 +431,6 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -496,9 +431,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name == "lm_head.weight" and self.tie_word_embeddings:
# Falcon uses tied embeddings except Falcon-11b.
continue
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -563,8 +495,78 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -563,8 +495,78 @@ class FalconForCausalLM(nn.Module, SupportsPP):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
return loaded_params return loaded_params
class FalconForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
self.tie_word_embeddings = (config.tie_word_embeddings
if config.tie_word_embeddings is not None
else True)
if self.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
...@@ -875,7 +875,8 @@ class Florence2MultiModalProcessor( ...@@ -875,7 +875,8 @@ class Florence2MultiModalProcessor(
Florence2MultiModalProcessor, Florence2MultiModalProcessor,
info=Florence2ProcessingInfo, info=Florence2ProcessingInfo,
dummy_inputs=Florence2DummyInputsBuilder) dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors ...@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -65,6 +66,14 @@ class FuyuImagePatchInputs(TypedDict): ...@@ -65,6 +66,14 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class FuyuProcessingInfo(BaseProcessingInfo): class FuyuProcessingInfo(BaseProcessingInfo):
...@@ -183,6 +192,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -183,6 +192,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0] processed_outputs["image_patches"] = image_patches[0]
# get patch grid size for each image
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)
mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
...@@ -202,7 +224,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -202,7 +224,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image")) return dict(image_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -306,13 +329,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -306,13 +329,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image patches. " raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}") f"Got type: {type(image_patches)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches) image_patches_flat = flatten_bn(image_patches)
embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs( return FuyuImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=self._validate_pixel_values( flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)), flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat], patches_per_image=[x.size(0) for x in image_patches_flat],
embed_is_patch=embed_is_patch,
) )
return None return None
...@@ -325,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -325,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
assert self.vision_embed_tokens is not None assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens( vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat) image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0) return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings( def get_multimodal_embeddings(
...@@ -332,8 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -332,8 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -343,8 +379,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -343,8 +379,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds,
_IMAGE_TOKEN_ID) select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module): ...@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module):
intermediate_size: int, intermediate_size: int,
hidden_activation: str, hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
self.down_proj = RowParallelLinear(intermediate_size, prefix=f"{prefix}.down_proj",
hidden_size, )
bias=False,
quant_config=quant_config)
if hidden_activation != "gelu_pytorch_tanh": if hidden_activation != "gelu_pytorch_tanh":
raise ValueError( raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
...@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module): ...@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.attention_bias, bias=config.attention_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=config.attention_bias, bias=config.attention_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
...@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module): ...@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation, hidden_activation=config.hidden_activation,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp",
) )
self.input_layernorm = GemmaRMSNorm(config.hidden_size, self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module): ...@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=f"{prefix}.embed_tokens",
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
......
...@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -60,12 +59,9 @@ class Gemma3ImagePixelInputs(TypedDict): ...@@ -60,12 +59,9 @@ class Gemma3ImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs Gemma3ImageInputs = Gemma3ImagePixelInputs
...@@ -295,8 +291,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -295,8 +291,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM # HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None: if (images := mm_data.get("images")) is not None:
assert isinstance(images, list)
parsed_images = (self._get_data_parser().parse_mm_data({ parsed_images = (self._get_data_parser().parse_mm_data({
"image": "image":
images images
...@@ -319,11 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -319,11 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
tokenizer.encode(image_repl, add_special_tokens=False) tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features for image_repl in image_repl_features
] ]
num_embeds = [
len(image_repl_feature_tokens)
for image_repl_feature_tokens in image_repls_feature_tokens
]
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token] image_token_id = vocab[tokenizer.image_token]
...@@ -356,7 +345,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -356,7 +345,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
"image", num_crops + 1), "image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -585,7 +573,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -585,7 +573,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None) num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None) embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds." assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None: if pixel_values is None:
...@@ -603,19 +590,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -603,19 +590,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True) num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs( return Gemma3ImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1, num_patches=num_crops + 1,
embed_is_patch=embed_is_patch, embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
) )
def _image_pixels_to_features( def _image_pixels_to_features(
...@@ -630,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -630,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _process_image_input( def _process_image_input(
self, self,
image_input: Gemma3ImageInputs, image_input: Gemma3ImageInputs,
) -> tuple[torch.Tensor, ...]: ) -> list[torch.Tensor]:
assert self.vision_tower is not None assert self.vision_tower is not None
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
...@@ -642,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -642,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
) )
image_embeds = self.multi_modal_projector(image_features) image_embeds = self.multi_modal_projector(image_features)
return image_embeds.split(num_patches.tolist()) return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
...@@ -652,15 +637,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -652,15 +637,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False): return scatter_patch_features(
return image_features image_features,
image_input["embed_is_patch"],
return flatten_2d_lists( )
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -689,7 +669,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -689,7 +669,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -44,7 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm.utils import is_hip,W8a8GetCacheJSON from vllm.utils import is_hip,W8a8GetCacheJSON
...@@ -219,6 +219,13 @@ class GPTNeoXModel(nn.Module): ...@@ -219,6 +219,13 @@ class GPTNeoXModel(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_in(input_ids) return self.embed_in(input_ids)
...@@ -244,67 +251,6 @@ class GPTNeoXModel(nn.Module): ...@@ -244,67 +251,6 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
return hidden_states return hidden_states
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -342,6 +288,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -342,6 +288,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
#当为triton支持推理的时候不能进行处理 #当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors": if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
...@@ -387,6 +334,64 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -387,6 +334,64 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
n=int(key.split('_')[1]) n=int(key.split('_')[1])
k=int(key.split('_')[2]) k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value) ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
...@@ -17,16 +17,14 @@ ...@@ -17,16 +17,14 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
Idefics3Processor) Idefics3Processor)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead ...@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.parse import ImageProcessorItems # yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalDataItems, MultiModalDataItems,
MultiModalFieldConfig, MultiModalFieldConfig,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
encode_tokens)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal ...@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
class Idefics3ImagePixelInputs(TypedDict): class Idefics3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor pixel_values: torch.Tensor
""" """
Shape: `(batch_size * num_images * num_patches, Shape: `(batch_size * num_images * num_patches,
num_channels, height, width)` num_channels, height, width)`
""" """
pixel_attention_mask: Optional[torch.BoolTensor] pixel_attention_mask: torch.Tensor
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict): class Idefics3ImageEmbeddingInputs(TypedDict):
...@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict): ...@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
...@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Mapping[str, int]: ) -> Mapping[str, int]:
hf_processor = self.get_hf_processor() return {"image": self.get_max_image_tokens()}
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_processor.size['longest_edge'],
image_height=image_processor.size['longest_edge'],
)
num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
# Calculate Non-image-token length
# NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
# but not for Idefic3, so we need to tokenize them to get actual length.
tokenizer = self.get_tokenizer()
tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
# linebreak and <fake_token_around_image> always cost 1 token
fake_token_len = lb_len = 1
non_image_token = (grid_w * grid_h) * (
tile_token_len + fake_token_len) + glob_token_len + (
grid_h + 1) * lb_len + fake_token_len
return {"image": num_image_token + non_image_token}
def _resize_output_size(self, def _resize_output_size(self,
*, *,
height: int, height: int,
width: int, width: int,
max_len: Optional[int] = None, max_len: Optional[int] = None,
min_len: Optional[int] = 1, min_len: int = 1,
max_size: Optional[int] = None) -> tuple[int, int]: max_size: Optional[int] = None) -> tuple[int, int]:
# Set default value for max_len if not provided # Set default value for max_len if not provided
max_len = max(height, width) if max_len is None else max_len max_len = max(height, width) if max_len is None else max_len
...@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
size: Optional[dict[str, object]] = None, processor: Optional[Idefics3Processor],
) -> tuple[int, int]: ) -> tuple[int, int]:
hf_processor = self.get_hf_processor(size=size) if processor is None:
image_processor: Idefics3ImageProcessor = hf_processor.image_processor processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = processor.image_processor
max_image_size = image_processor.max_image_size['longest_edge'] max_image_size = image_processor.max_image_size['longest_edge']
size = image_processor.size['longest_edge'] size = image_processor.size['longest_edge']
assert size % max_image_size == 0, ( assert size % max_image_size == 0, (
...@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
grid_h = grid_w = 0 grid_h = grid_w = 0
return grid_w, grid_h return grid_w, grid_h
def get_num_patches(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
)
return grid_w * grid_h + 1
def get_image_repl(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Idefics3Processor],
) -> str:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_img_token = processor.global_image_tag
image_seq_len = processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
p_img = image_token * image_seq_len
global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if grid_w == 0 and grid_h == 0:
return global_img_placeholder + fake_image_token
tiles_placeholder = list[str]()
for i in range(grid_h):
for j in range(grid_w):
placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1,
n_w=j + 1)
tiles_placeholder.append(placeholder_per_tile)
# Add line break if it is the last tile in the row
if j == grid_w - 1:
tiles_placeholder.append("\n")
return "".join([
*tiles_placeholder,
"\n",
global_img_placeholder,
fake_image_token,
])
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = processor.image_processor
return ImageSize(
width=image_processor.size["longest_edge"],
height=image_processor.size["longest_edge"],
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
): ):
...@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ...@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
hf_processor = self.info.get_hf_processor() hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge'] longest_edge = image_processor.max_image_size['longest_edge']
image_token: str = hf_processor.image_token.content image_token = hf_processor.image_token.content
mm_data = { mm_data = {
"image": "image":
...@@ -232,7 +335,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ...@@ -232,7 +335,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
) )
class Idefics3MultimodalProcessor( class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]): BaseMultiModalProcessor[Idefics3ProcessingInfo]):
def _call_hf_processor( def _call_hf_processor(
...@@ -241,26 +344,61 @@ class Idefics3MultimodalProcessor( ...@@ -241,26 +344,61 @@ class Idefics3MultimodalProcessor(
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if mm_data: # Text-only input not supported in composite processor
processed_outputs = super()._call_hf_processor( if not (images := mm_data.get("images", [])):
prompt, mm_data, mm_kwargs) prompt_ids = self.info.get_tokenizer().encode(prompt)
image_grids = [ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
self.info._get_image_feature_grid_size( return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
image_width=img.width,
image_height=img.height, processed_outputs = super()._call_hf_processor(
**mm_kwargs, prompt,
) for img in mm_data["images"] mm_data,
] mm_kwargs,
image_patches = list(map(lambda x: math.prod(x) + 1, image_grids)) )
for key in ("pixel_values", "pixel_attention_mask"):
data = processed_outputs.pop(key) parsed_images = (self._get_data_parser().parse_mm_data({
data = data.flatten(0, 1).split(image_patches) "image": images
processed_outputs[key] = data }).get_items("image", ImageProcessorItems))
else: image_sizes = [
tokenizer = self.info.get_tokenizer() parsed_images.get_image_size(i) for i in range(len(parsed_images))
processed_outputs = tokenizer(prompt, ]
add_special_tokens=True, hf_processor = self.info.get_hf_processor(**mm_kwargs)
return_tensors="pt")
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [
self.info.get_num_patches(
image_width=size.width,
image_height=size.height,
processor=hf_processor,
) for size in image_sizes
]
processed_outputs["num_patches"] = torch.tensor(num_patches)
# Remove the extra batch dimension
processed_outputs["pixel_values"].squeeze_(0)
processed_outputs["pixel_attention_mask"].squeeze_(0)
return processed_outputs return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -268,10 +406,16 @@ class Idefics3MultimodalProcessor( ...@@ -268,10 +406,16 @@ class Idefics3MultimodalProcessor(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.flat_from_sizes(
pixel_attention_mask=MultiModalFieldConfig.batched("image"), "image", num_patches),
pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -281,42 +425,18 @@ class Idefics3MultimodalProcessor( ...@@ -281,42 +425,18 @@ class Idefics3MultimodalProcessor(
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content image_token = hf_processor.image_token.content
fake_image_token = hf_processor.fake_image_token.content
global_img_token = hf_processor.global_image_tag
image_seq_len = hf_processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
p_img = image_token * image_seq_len
global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
def get_replacement_idefics3(item_idx: int) -> str: def get_replacement_idefics3(item_idx: int) -> str:
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
grid_w, grid_h = self.info._get_image_feature_grid_size(
return self.info.get_image_repl(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
**hf_processor_mm_kwargs, processor=hf_processor,
) )
if grid_w == 0 and grid_h == 0:
image_placeholder = global_img_placeholder
else:
tiles_placeholder = list[str]()
for i in range(grid_h):
for j in range(grid_w):
placeholder_per_tile = tile_img_placeholder.format(
n_h=i + 1, n_w=j + 1)
tiles_placeholder.append(placeholder_per_tile)
# Add line break if it is the last tile in the row
if j == grid_w - 1:
tiles_placeholder.append("\n")
image_placeholder = "".join(
[*tiles_placeholder, "\n", global_img_placeholder])
return image_placeholder + fake_image_token
return [ return [
PromptReplacement( PromptReplacement(
...@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module): ...@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module):
config.vision_config.patch_size)**2) / (config.scale_factor**2)) config.vision_config.patch_size)**2) / (config.scale_factor**2))
self.image_token_id = self.config.image_token_id self.image_token_id = self.config.image_token_id
def _validate_pixel_values( def image_pixels_to_features(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if isinstance(pixel_values, list):
pixel_values = torch.cat(pixel_values, dim=1)
pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
else:
pixel_values = flatten_bn(pixel_values)
pixel_attention_mask = flatten_bn(pixel_attention_mask)
return Idefics3ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
pixel_attention_mask: Optional[torch.BoolTensor] = None, pixel_attention_mask: torch.Tensor,
) -> NestedTensors: ) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since # NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower # this is already done inside the vision tower
num_patches = [x.size(0) for x in pixel_values]
pixel_values = pixel_values.to( pixel_values = pixel_values.to(
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
) # fp16 compatibility ) # fp16 compatibility
...@@ -502,17 +562,9 @@ class Idefics3Model(nn.Module): ...@@ -502,17 +562,9 @@ class Idefics3Model(nn.Module):
pixel_values = pixel_values[real_images_inds].contiguous() pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask # Handle the vision attention mask
if pixel_attention_mask is None: # Remove padding images from the mask
pixel_attention_mask = torch.ones( pixel_attention_mask = pixel_attention_mask[
size=(pixel_values.size(0), pixel_values.size(2), real_images_inds].contiguous()
pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask[
real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, patches_subgrid = pixel_attention_mask.unfold(dimension=1,
...@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module): ...@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module):
patch_attention_mask=patch_attention_mask, patch_attention_mask=patch_attention_mask,
) )
return image_hidden_states.split(num_patches) return image_hidden_states
def _process_image_pixels(
self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
assert self.vision_model is not None
pixel_values = inputs["data"]
pixel_attention_mask = inputs["pixel_attention_mask"]
return self._image_pixels_to_features(pixel_values,
pixel_attention_mask)
def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
num_patches = [x.size(0) for x in image_features]
image_features = torch.cat(image_features)
return self.connector(image_features).split(num_patches)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -575,7 +607,7 @@ class Idefics3Model(nn.Module): ...@@ -575,7 +607,7 @@ class Idefics3Model(nn.Module):
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Idefics3MultimodalProcessor, Idefics3MultiModalProcessor,
info=Idefics3ProcessingInfo, info=Idefics3ProcessingInfo,
dummy_inputs=Idefics3DummyInputsBuilder) dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...@@ -616,13 +648,118 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -616,13 +648,118 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_attention_mask = kwargs.pop("pixel_attention_mask")
if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_attention_mask. "
f"Got type: {type(pixel_attention_mask)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True)
pixel_attention_mask = flatten_bn(pixel_attention_mask,
concat=True)
num_patches = flatten_bn(num_patches, concat=True)
return Idefics3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
def _process_image_pixels(
self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
pixel_values = inputs["pixel_values"]
pixel_attention_mask = inputs["pixel_attention_mask"]
return self.model.image_pixels_to_features(
pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
def _process_image_input(
self,
image_input: ImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
image_features = self._process_image_pixels(image_input)
image_features = self.model.connector(image_features)
num_patches = image_input["num_patches"]
return [
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self.model._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self.model._process_image_input(image_input)
return vision_embeddings image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -632,8 +769,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -632,8 +769,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids,
self.config.image_token_id) inputs_embeds,
select_patch_features(multimodal_embeddings),
self.config.image_token_id,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -415,6 +415,35 @@ def is_hybrid( ...@@ -415,6 +415,35 @@ def is_hybrid(
return isinstance(model, IsHybrid) return isinstance(model, IsHybrid)
@runtime_checkable
class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True
@runtime_checkable
class _HasNoOpsType(Protocol):
has_noops: ClassVar[Literal[True]]
@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
...
@overload
def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]:
...
def has_noops(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type):
return isinstance(model, _HasNoOpsType)
return isinstance(model, HasNoOps)
@runtime_checkable @runtime_checkable
class SupportsCrossEncoding(Protocol): class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding.""" """The interface required for all models that support cross encoding."""
......
...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
...@@ -66,16 +65,13 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -66,16 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: NestedTensors data: Union[torch.Tensor, list[torch.Tensor]]
""" """
A tensor of shape `(num_images, total_image_feature_size, hidden_size)` A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)` or a list of tensors of shape `(total_image_feature_size, hidden_size)`
...@@ -426,7 +422,6 @@ class BaseInternVLProcessor(ABC): ...@@ -426,7 +422,6 @@ class BaseInternVLProcessor(ABC):
tokenizer = self.tokenizer tokenizer = self.tokenizer
image_token_id = self.image_token_id image_token_id = self.image_token_id
num_embeds = list[int]()
embed_is_patch = list[torch.Tensor]() embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst: for pixel_values in pixel_values_lst:
...@@ -438,11 +433,9 @@ class BaseInternVLProcessor(ABC): ...@@ -438,11 +433,9 @@ class BaseInternVLProcessor(ABC):
add_special_tokens=False) add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text] text = [t.replace('<image>', image_repl.full, 1) for t in text]
num_embeds.append(len(feature_tokens))
embed_is_patch.append( embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id) torch.tensor(feature_tokens) == image_token_id)
image_inputs["num_embeds"] = torch.tensor(num_embeds)
image_inputs["embed_is_patch"] = embed_is_patch image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
...@@ -607,7 +600,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -607,7 +600,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"image", image_num_patches), "image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -840,7 +832,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -840,7 +832,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None) image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None) embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None: if pixel_values_flat is None and image_embeds is None:
...@@ -873,12 +864,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -873,12 +864,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
...@@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat), pixel_values_flat),
num_patches=image_num_patches, num_patches=image_num_patches,
embed_is_patch=embed_is_patch, embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -894,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -894,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input( def _process_image_input(
self, self,
image_input: InternVLImageInputs, image_input: InternVLImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
...@@ -934,16 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -934,16 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False) if image_input["type"] != "pixel_values":
or image_input["type"] != "pixel_values"):
return image_features return image_features
return flatten_2d_lists( return scatter_patch_features(
scatter_patch_features(*args) for args in zip( image_features,
image_features, image_input["embed_is_patch"],
image_input["num_embeds"], )
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -978,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -978,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
......
...@@ -73,6 +73,7 @@ class LlamaMLP(nn.Module): ...@@ -73,6 +73,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
prefix: str = "", prefix: str = "",
reduce_results: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
...@@ -87,6 +88,7 @@ class LlamaMLP(nn.Module): ...@@ -87,6 +88,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
) )
if hidden_act != "silu": if hidden_act != "silu":
...@@ -628,10 +630,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -628,10 +630,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm", "ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens", "tok_embeddings": "model.embed_tokens",
"output": "lm_head", "output": "lm_head",
"norm": "model.norm" "norm": "model.norm",
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -640,9 +646,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -640,9 +646,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
...@@ -678,8 +683,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -678,8 +683,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self,
return LlamaModel(vllm_config=vllm_config, prefix=prefix) vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type
import torch
from torch import nn
from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter)
class Llama4MoE(nn.Module):
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router")
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
apply_router_weight_on_input=True,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False, # We need to do scatter before reduce
)
def forward(self, hidden_states):
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
experts_out = routed_out + shared_out
if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out)
return experts_out
class Llama4Attention(nn.Module):
def __init__(self,
config: Llama4TextConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
self.no_rope_layers = config.no_rope_layers
self.nope = self.no_rope_layers[self.layer_idx] == 0
self.use_qk_norm = config.use_qk_norm and not self.nope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
# TODO: attn_temperature_tuning should be a bool in huggingface
self.attn_temperature_tuning = self.nope and \
config.attn_temperature_tuning > 0
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm(
hidden_size=self.q_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
) if not self.nope else None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None:
q = self.q_norm(q.float()).to(q.dtype)
if self.k_norm is not None:
k = self.k_norm(k.float()).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function
# is customized to not affect short context
# while working at very long context
# https://arxiv.org/abs/2501.19399
#
# We should apply temperature tuning between (after) rotary / QK norm
# and (before) attention.
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Llama4DecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: Llama4TextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
self.layer_idx = extract_layer_index(prefix)
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
bias_o_proj=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
is_moe_layer = (self.layer_idx +
1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.feed_forward",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@support_torch_compile
class Llama4Model(LlamaModel):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_moe_expert_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
loaded_params: Set[str],
expert_params_mapping: List[Tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
expert_param_loaded = False
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-1)
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
new_loaded_weight = loaded_weight
if fused:
e_str, _, proj_str, _ = weight_name.split('.')
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
if weight_name not in name:
continue
full_param_name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[full_param_name]
weight_loader = param.weight_loader
if fused:
if "w13" in full_param_name:
shard_idx = 0 if shard_id == "w1" else 1
new_loaded_weight = new_loaded_weight[shard_idx]
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
layer_idx = extract_layer_index(name)
# EP mapping
expert_map = self.layers[
layer_idx].feed_forward.experts.expert_map
if expert_map is not None:
local_expert_indices = (expert_map != -1) \
.nonzero() \
.flatten() \
.to(new_loaded_weight.device)
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
# TODO: add EP support for non fused weights
pass
weight_loader(param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id)
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
fused_experts_params = False
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts)
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or "experts" in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
moe_loaded = self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params)
if not moe_loaded:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config)
vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False)
LlamaForCausalLM.__init__(self,
vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
return loader.load_weights(weights)
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
modules = name.split(".")
# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif ("wq" in modules or "q_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
return name, loaded_weight
...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -73,12 +72,9 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -73,12 +72,9 @@ class PixtralHFImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class LlavaImageEmbeddingInputs(TypedDict): class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -358,15 +354,10 @@ class PixtralHFMultiModalProcessor( ...@@ -358,15 +354,10 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2], image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"] ) for pixel_value in processed_outputs["pixel_values"]
] ]
num_embeds = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# later use `num_embeds` to get per-image masks.
embed_is_patch = [ embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows) torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes for ncols, nrows in tile_sizes
] ]
processed_outputs["num_embeds"] = num_embeds
processed_outputs["embed_is_patch"] = embed_is_patch processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
...@@ -378,7 +369,6 @@ class PixtralHFMultiModalProcessor( ...@@ -378,7 +369,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -627,16 +617,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -627,16 +617,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
num_embeds = kwargs.pop("num_embeds") embed_is_patch = flatten_bn(embed_is_patch)
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch, embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
) )
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
...@@ -728,19 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -728,19 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False) if image_input["type"] != "pixel_values_pixtral":
or image_input["type"] != "pixel_values_pixtral"):
# The path is used for pixtral (V0 only) and llava (V0/V1) # The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings return image_features
return flatten_2d_lists( return scatter_patch_features(
scatter_patch_features(*args) for args in zip( image_features,
vision_embeddings, image_input["embed_is_patch"],
image_input["num_embeds"], )
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -806,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -806,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
...@@ -886,6 +868,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -886,6 +868,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"] mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"]
# We reimplement the functionality of MLlavaProcessor from # We reimplement the functionality of MLlavaProcessor from
# https://github.com/TIGER-AI-Lab/Mantis.git # https://github.com/TIGER-AI-Lab/Mantis.git
...@@ -934,6 +917,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -934,6 +917,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt=prompt, prompt=prompt,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment