Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
......@@ -38,7 +38,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
......@@ -927,7 +928,11 @@ class ChameleonModel(nn.Module):
info=ChameleonProcessingInfo,
dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
SupportsPP, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -12,6 +12,7 @@ import os
import re
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -30,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -305,7 +306,12 @@ class GLMTransformer(nn.Module):
return hidden_states
class ChatGLMModel(nn.Module):
@support_torch_compile
class ChatGLMModel(nn.Module, SupportsQuant):
packed_modules_mapping = {
"linear_proj.merged_proj":
["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......@@ -458,7 +464,6 @@ class ChatGLMModel(nn.Module):
class ChatGLMBaseModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".word_embeddings": ""}, )
......@@ -516,7 +521,8 @@ class ChatGLMBaseModel(nn.Module):
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsQuant):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
......
......@@ -24,7 +24,6 @@
from typing import Iterable, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig
......@@ -50,7 +49,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -333,7 +332,7 @@ class CohereModel(nn.Module):
return hidden_states
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Set, Tuple
import torch
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from .utils import is_pp_missing_parameter
class DeciLMForCausalLM(LlamaForCausalLM):
"""
Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overridden with a value
per layer.
Usually, in the HuggingFace implementation, instead of
"config.num_key_value_heads", we use
"config.num_key_value_heads_per_layer[i]" which varies.
Currently, PagedAttention does not work well with variable GQA, so we
normalize the weights upon loading, and use uniform GQA with the max value
instead.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "k_proj" in name or "v_proj" in name:
loaded_weight = self._degroup_weight(loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
head_size = self.config.hidden_size // self.config.num_attention_heads
target_num_kv_heads = self.config.num_key_value_heads
num_kv_heads = loaded_weight.shape[0] // head_size
n_repeats = target_num_kv_heads / num_kv_heads
assert n_repeats == int(n_repeats)
n_repeats = int(n_repeats)
loaded_weight = loaded_weight.view(num_kv_heads, head_size,
hidden_size)
loaded_weight = torch.repeat_interleave(loaded_weight,
repeats=n_repeats,
dim=0)
loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
hidden_size)
return loaded_weight
......@@ -509,7 +509,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
# 根据self.tile_tag & self.global_view_pos填充image token sequence
# fill image token based on self.tile_tag & self.global_view_pos
tile_index = 0
vision_embeddings = []
for jdx in range(images_spatial_crop.size(0)):
......
......@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
needs to have truncated_vocab_size (=k) as an attribute.
4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
module with regards to the use of additional RMS norms. The original
EAGLE architecture 1) skips the pre-attention norm in its first
transformer block, and 2) skips the final output norm, both of which we
found to be suboptimal. We also add the support for separate norms
applying to both the token embedding and hidden states before projection
as in DeepSeek MTP, which we found to improve performance as well.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......@@ -81,9 +90,22 @@ class EAGLE(nn.Module):
# While weights and biases are generally not needed,
# they are retained here to support certain unit tests
# (e.g., spec_decode/e2e/test_eagle_correctness.py).
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
weight=self.model.model.layers[0].input_layernorm.weight)
self.model.model.norm = DummyOutputNorm()
if not hasattr(self.config.model,
"skip_prenorm") or self.config.model.skip_prenorm:
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
weight=self.model.model.layers[0].input_layernorm.weight)
if not hasattr(
self.config.model,
"skip_output_norm") or self.config.model.skip_output_norm:
self.model.model.norm = DummyOutputNorm()
self.add_para_norm = False
if hasattr(self.config.model,
"add_para_norm") and self.config.model.add_para_norm:
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.add_para_norm = True
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
......@@ -128,8 +150,17 @@ class EAGLE(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.fc(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
if self.add_para_norm:
inputs_embeds = torch.cat([
self.enorm(inputs_embeds),
self.hnorm(previous_hidden_states)
],
dim=-1)
else:
inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
dim=-1)
inputs_embeds = self.fc(inputs_embeds)
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
......@@ -190,6 +221,14 @@ class EAGLE(nn.Module):
else:
logger.warning_once("Found bias in the loaded weights but "
"the model config doesn't have bias.")
elif name.startswith("enorm.weight"):
weight_loader = getattr(self.enorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.enorm.weight, loaded_weight)
elif name.startswith("hnorm.weight"):
weight_loader = getattr(self.hnorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.hnorm.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
......
......@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -313,6 +313,7 @@ class ExaoneModel(nn.Module):
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
......@@ -384,6 +385,72 @@ class ExaoneModel(nn.Module):
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
......@@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
loader = AutoWeightsLoader(
self,
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
......@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -382,6 +382,17 @@ class FalconModel(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
......@@ -407,82 +418,6 @@ class FalconModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
self.tie_word_embeddings = (config.tie_word_embeddings
if config.tie_word_embeddings is not None
else True)
if self.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
total_num_heads = self.config.num_attention_heads
......@@ -496,9 +431,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight" and self.tie_word_embeddings:
# Falcon uses tied embeddings except Falcon-11b.
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -563,8 +495,78 @@ class FalconForCausalLM(nn.Module, SupportsPP):
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
return loaded_params
class FalconForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
self.tie_word_embeddings = (config.tie_word_embeddings
if config.tie_word_embeddings is not None
else True)
if self.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
......@@ -875,7 +875,8 @@ class Florence2MultiModalProcessor(
Florence2MultiModalProcessor,
info=Florence2ProcessingInfo,
dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
......@@ -65,6 +66,14 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class FuyuProcessingInfo(BaseProcessingInfo):
......@@ -183,6 +192,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0]
# get patch grid size for each image
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)
mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
def _apply_hf_processor_tokens_only(
......@@ -202,7 +224,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
return dict(image_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
......@@ -306,13 +329,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches)
embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs(
type="image_patches",
flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat],
embed_is_patch=embed_is_patch,
)
return None
......@@ -325,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(
......@@ -332,8 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
......@@ -343,8 +379,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
input_ids, inputs_embeds,
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
......
......@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module):
intermediate_size: int,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_activation != "gelu_pytorch_tanh":
raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
......@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module):
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
......@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
......
......@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
......@@ -60,12 +59,9 @@ class Gemma3ImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
......@@ -295,8 +291,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
assert isinstance(images, list)
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
......@@ -319,11 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
num_embeds = [
len(image_repl_feature_tokens)
for image_repl_feature_tokens in image_repls_feature_tokens
]
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
......@@ -356,7 +345,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
......@@ -585,7 +573,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
......@@ -603,19 +590,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _image_pixels_to_features(
......@@ -630,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> tuple[torch.Tensor, ...]:
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
......@@ -642,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
)
image_embeds = self.multi_modal_projector(image_features)
return image_embeds.split(num_patches.tolist())
return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
......@@ -652,15 +637,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
......@@ -689,7 +669,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
......
......@@ -44,7 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.utils import is_hip,W8a8GetCacheJSON
......@@ -219,6 +219,13 @@ class GPTNeoXModel(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_in(input_ids)
......@@ -244,67 +251,6 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
......@@ -342,6 +288,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
......@@ -387,6 +334,64 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -17,16 +17,14 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
Idefics3Processor)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate)
PromptReplacement, PromptUpdate,
encode_tokens)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
pixel_values: torch.Tensor
"""
Shape: `(batch_size * num_images * num_patches,
num_channels, height, width)`
"""
pixel_attention_mask: Optional[torch.BoolTensor]
pixel_attention_mask: torch.Tensor
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict):
......@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
......@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_processor.size['longest_edge'],
image_height=image_processor.size['longest_edge'],
)
num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
# Calculate Non-image-token length
# NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
# but not for Idefic3, so we need to tokenize them to get actual length.
tokenizer = self.get_tokenizer()
tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
# linebreak and <fake_token_around_image> always cost 1 token
fake_token_len = lb_len = 1
non_image_token = (grid_w * grid_h) * (
tile_token_len + fake_token_len) + glob_token_len + (
grid_h + 1) * lb_len + fake_token_len
return {"image": num_image_token + non_image_token}
return {"image": self.get_max_image_tokens()}
def _resize_output_size(self,
*,
height: int,
width: int,
max_len: Optional[int] = None,
min_len: Optional[int] = 1,
min_len: int = 1,
max_size: Optional[int] = None) -> tuple[int, int]:
# Set default value for max_len if not provided
max_len = max(height, width) if max_len is None else max_len
......@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
size: Optional[dict[str, object]] = None,
processor: Optional[Idefics3Processor],
) -> tuple[int, int]:
hf_processor = self.get_hf_processor(size=size)
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
if processor is None:
processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = processor.image_processor
max_image_size = image_processor.max_image_size['longest_edge']
size = image_processor.size['longest_edge']
assert size % max_image_size == 0, (
......@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
grid_h = grid_w = 0
return grid_w, grid_h
def get_num_patches(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
)
return grid_w * grid_h + 1
def get_image_repl(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Idefics3Processor],
) -> str:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_img_token = processor.global_image_tag
image_seq_len = processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
p_img = image_token * image_seq_len
global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if grid_w == 0 and grid_h == 0:
return global_img_placeholder + fake_image_token
tiles_placeholder = list[str]()
for i in range(grid_h):
for j in range(grid_w):
placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1,
n_w=j + 1)
tiles_placeholder.append(placeholder_per_tile)
# Add line break if it is the last tile in the row
if j == grid_w - 1:
tiles_placeholder.append("\n")
return "".join([
*tiles_placeholder,
"\n",
global_img_placeholder,
fake_image_token,
])
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = processor.image_processor
return ImageSize(
width=image_processor.size["longest_edge"],
height=image_processor.size["longest_edge"],
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
):
......@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge']
image_token: str = hf_processor.image_token.content
image_token = hf_processor.image_token.content
mm_data = {
"image":
......@@ -232,7 +335,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
)
class Idefics3MultimodalProcessor(
class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
def _call_hf_processor(
......@@ -241,26 +344,61 @@ class Idefics3MultimodalProcessor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
image_grids = [
self.info._get_image_feature_grid_size(
image_width=img.width,
image_height=img.height,
**mm_kwargs,
) for img in mm_data["images"]
]
image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
for key in ("pixel_values", "pixel_attention_mask"):
data = processed_outputs.pop(key)
data = data.flatten(0, 1).split(image_patches)
processed_outputs[key] = data
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
add_special_tokens=True,
return_tensors="pt")
# Text-only input not supported in composite processor
if not (images := mm_data.get("images", [])):
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
mm_kwargs,
)
parsed_images = (self._get_data_parser().parse_mm_data({
"image": images
}).get_items("image", ImageProcessorItems))
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [
self.info.get_num_patches(
image_width=size.width,
image_height=size.height,
processor=hf_processor,
) for size in image_sizes
]
processed_outputs["num_patches"] = torch.tensor(num_patches)
# Remove the extra batch dimension
processed_outputs["pixel_values"].squeeze_(0)
processed_outputs["pixel_attention_mask"].squeeze_(0)
return processed_outputs
def _get_mm_fields_config(
......@@ -268,10 +406,16 @@ class Idefics3MultimodalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_attention_mask=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
......@@ -281,42 +425,18 @@ class Idefics3MultimodalProcessor(
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
fake_image_token = hf_processor.fake_image_token.content
global_img_token = hf_processor.global_image_tag
image_seq_len = hf_processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
p_img = image_token * image_seq_len
global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
def get_replacement_idefics3(item_idx: int) -> str:
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
grid_w, grid_h = self.info._get_image_feature_grid_size(
return self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
**hf_processor_mm_kwargs,
processor=hf_processor,
)
if grid_w == 0 and grid_h == 0:
image_placeholder = global_img_placeholder
else:
tiles_placeholder = list[str]()
for i in range(grid_h):
for j in range(grid_w):
placeholder_per_tile = tile_img_placeholder.format(
n_h=i + 1, n_w=j + 1)
tiles_placeholder.append(placeholder_per_tile)
# Add line break if it is the last tile in the row
if j == grid_w - 1:
tiles_placeholder.append("\n")
image_placeholder = "".join(
[*tiles_placeholder, "\n", global_img_placeholder])
return image_placeholder + fake_image_token
return [
PromptReplacement(
......@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module):
config.vision_config.patch_size)**2) / (config.scale_factor**2))
self.image_token_id = self.config.image_token_id
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if isinstance(pixel_values, list):
pixel_values = torch.cat(pixel_values, dim=1)
pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
else:
pixel_values = flatten_bn(pixel_values)
pixel_attention_mask = flatten_bn(pixel_attention_mask)
return Idefics3ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
def image_pixels_to_features(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
) -> NestedTensors:
pixel_attention_mask: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
num_patches = [x.size(0) for x in pixel_values]
pixel_values = pixel_values.to(
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
) # fp16 compatibility
......@@ -502,17 +562,9 @@ class Idefics3Model(nn.Module):
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2),
pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask[
real_images_inds].contiguous()
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask[
real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1,
......@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module):
patch_attention_mask=patch_attention_mask,
)
return image_hidden_states.split(num_patches)
def _process_image_pixels(
self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
assert self.vision_model is not None
pixel_values = inputs["data"]
pixel_attention_mask = inputs["pixel_attention_mask"]
return self._image_pixels_to_features(pixel_values,
pixel_attention_mask)
def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
num_patches = [x.size(0) for x in image_features]
image_features = torch.cat(image_features)
return self.connector(image_features).split(num_patches)
return image_hidden_states
def get_input_embeddings(
self,
......@@ -575,7 +607,7 @@ class Idefics3Model(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(
Idefics3MultimodalProcessor,
Idefics3MultiModalProcessor,
info=Idefics3ProcessingInfo,
dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
......@@ -616,13 +648,118 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_attention_mask = kwargs.pop("pixel_attention_mask")
if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_attention_mask. "
f"Got type: {type(pixel_attention_mask)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True)
pixel_attention_mask = flatten_bn(pixel_attention_mask,
concat=True)
num_patches = flatten_bn(num_patches, concat=True)
return Idefics3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
def _process_image_pixels(
self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
pixel_values = inputs["pixel_values"]
pixel_attention_mask = inputs["pixel_attention_mask"]
return self.model.image_pixels_to_features(
pixel_values,
pixel_attention_mask=pixel_attention_mask,
)
def _process_image_input(
self,
image_input: ImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
image_features = self._process_image_pixels(image_input)
image_features = self.model.connector(image_features)
num_patches = image_input["num_patches"]
return [
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self.model._parse_and_validate_image_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self.model._process_image_input(image_input)
return vision_embeddings
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
......@@ -632,8 +769,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_id)
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
self.config.image_token_id,
)
return inputs_embeds
def forward(
......
......@@ -415,6 +415,35 @@ def is_hybrid(
return isinstance(model, IsHybrid)
@runtime_checkable
class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True
@runtime_checkable
class _HasNoOpsType(Protocol):
has_noops: ClassVar[Literal[True]]
@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
...
@overload
def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]:
...
def has_noops(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type):
return isinstance(model, _HasNoOpsType)
return isinstance(model, HasNoOps)
@runtime_checkable
class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding."""
......
......@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
......@@ -66,16 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: NestedTensors
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
......@@ -426,7 +422,6 @@ class BaseInternVLProcessor(ABC):
tokenizer = self.tokenizer
image_token_id = self.image_token_id
num_embeds = list[int]()
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst:
......@@ -438,11 +433,9 @@ class BaseInternVLProcessor(ABC):
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
num_embeds.append(len(feature_tokens))
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["num_embeds"] = torch.tensor(num_embeds)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
......@@ -607,7 +600,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
......@@ -840,7 +832,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
......@@ -873,12 +864,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs(
type="pixel_values",
......@@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat),
num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
raise AssertionError("This line should be unreachable.")
......@@ -894,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
......@@ -934,16 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False)
or image_input["type"] != "pixel_values"):
if image_input["type"] != "pixel_values":
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
......@@ -978,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
......
......@@ -73,6 +73,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -87,6 +88,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
......@@ -628,10 +630,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
"norm": "model.norm",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -640,9 +646,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
......@@ -678,8 +683,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type
import torch
from torch import nn
from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter)
class Llama4MoE(nn.Module):
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router")
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
apply_router_weight_on_input=True,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False, # We need to do scatter before reduce
)
def forward(self, hidden_states):
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
experts_out = routed_out + shared_out
if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out)
return experts_out
class Llama4Attention(nn.Module):
def __init__(self,
config: Llama4TextConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
self.no_rope_layers = config.no_rope_layers
self.nope = self.no_rope_layers[self.layer_idx] == 0
self.use_qk_norm = config.use_qk_norm and not self.nope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
# TODO: attn_temperature_tuning should be a bool in huggingface
self.attn_temperature_tuning = self.nope and \
config.attn_temperature_tuning > 0
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm(
hidden_size=self.q_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
) if not self.nope else None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None:
q = self.q_norm(q.float()).to(q.dtype)
if self.k_norm is not None:
k = self.k_norm(k.float()).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function
# is customized to not affect short context
# while working at very long context
# https://arxiv.org/abs/2501.19399
#
# We should apply temperature tuning between (after) rotary / QK norm
# and (before) attention.
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Llama4DecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: Llama4TextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
self.layer_idx = extract_layer_index(prefix)
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
bias_o_proj=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
is_moe_layer = (self.layer_idx +
1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.feed_forward",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@support_torch_compile
class Llama4Model(LlamaModel):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_moe_expert_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
loaded_params: Set[str],
expert_params_mapping: List[Tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
expert_param_loaded = False
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-1)
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
new_loaded_weight = loaded_weight
if fused:
e_str, _, proj_str, _ = weight_name.split('.')
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
if weight_name not in name:
continue
full_param_name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[full_param_name]
weight_loader = param.weight_loader
if fused:
if "w13" in full_param_name:
shard_idx = 0 if shard_id == "w1" else 1
new_loaded_weight = new_loaded_weight[shard_idx]
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
layer_idx = extract_layer_index(name)
# EP mapping
expert_map = self.layers[
layer_idx].feed_forward.experts.expert_map
if expert_map is not None:
local_expert_indices = (expert_map != -1) \
.nonzero() \
.flatten() \
.to(new_loaded_weight.device)
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
# TODO: add EP support for non fused weights
pass
weight_loader(param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id)
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
fused_experts_params = False
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts)
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or "experts" in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
moe_loaded = self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params)
if not moe_loaded:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config)
vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False)
LlamaForCausalLM.__init__(self,
vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
return loader.load_weights(weights)
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
modules = name.split(".")
# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif ("wq" in modules or "q_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
return name, loaded_weight
......@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......@@ -73,12 +72,9 @@ class PixtralHFImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
......@@ -358,15 +354,10 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_embeds = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# later use `num_embeds` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_embeds"] = num_embeds
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
......@@ -378,7 +369,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
......@@ -627,16 +617,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
return LlavaImagePixelInputs(
......@@ -728,19 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False)
or image_input["type"] != "pixel_values_pixtral"):
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
......@@ -806,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
......@@ -886,6 +868,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"]
# We reimplement the functionality of MLlavaProcessor from
# https://github.com/TIGER-AI-Lab/Mantis.git
......@@ -934,6 +917,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment