Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
...@@ -28,13 +28,13 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -28,13 +28,13 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig, MlpProjectorConfig,
VisionEncoderConfig) VisionEncoderConfig)
from vllm.transformers_utils.processors.deepseek_vl2 import ( from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor) DeepseekVLV2Processor)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
...@@ -133,8 +133,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): ...@@ -133,8 +133,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config) return self.ctx.get_hf_config(DeepseekVLV2Config)
def get_hf_processor(self) -> DeepseekVLV2Processor: def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(DeepseekVLV2Processor) return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
...@@ -308,13 +308,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -308,13 +308,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.text_config = config.text_config self.text_config = config.text_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
tokenizer = cached_get_tokenizer( tokenizer = cached_tokenizer_from_config(model_config)
model_config.tokenizer, self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
)
self.image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN)
self.vision = self._init_vision_module(self.vision_config, self.vision = self._init_vision_module(self.vision_config,
quant_config, quant_config,
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -18,6 +19,8 @@ from vllm.sequence import IntermediateTensors ...@@ -18,6 +19,8 @@ from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix from .utils import maybe_prefix
logger = init_logger(__name__)
class DummyInputLayerNorm(nn.Module): class DummyInputLayerNorm(nn.Module):
...@@ -190,8 +193,8 @@ class EAGLE(nn.Module): ...@@ -190,8 +193,8 @@ class EAGLE(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(self.fc.bias, loaded_weight) weight_loader(self.fc.bias, loaded_weight)
else: else:
raise ValueError("Found bias in the loaded weights " logger.warning_once("Found bias in the loaded weights but "
"but the model config doesn't have bias") "the model config doesn't have bias.")
elif name.startswith("model.lm_head.") or name.startswith( elif name.startswith("model.lm_head.") or name.startswith(
"model.model."): "model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight model_weights[name.split("model.", 1)[-1]] = loaded_weight
...@@ -200,12 +203,21 @@ class EAGLE(nn.Module): ...@@ -200,12 +203,21 @@ class EAGLE(nn.Module):
else: else:
model_weights[f"model.{name}"] = loaded_weight model_weights[f"model.{name}"] = loaded_weight
lm_head_weight = model_weights.pop("lm_head.weight") if "lm_head.weight" in model_weights:
lm_head_weight = model_weights.pop("lm_head.weight")
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
if self.token_map is not None and\ lm_head_weight = lm_head_weight[self.token_map]
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map] else:
# NOTE(Shangming): initialize the placeholder for lm_head weight.
lm_head_weight = torch.zeros(
self.lm_head.org_vocab_size,
self.lm_head.embedding_dim,
dtype=self.config.torch_dtype,
)
weight_loader = getattr(self.lm_head.weight, "weight_loader", weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -71,8 +71,8 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -71,8 +71,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(FuyuConfig) return self.ctx.get_hf_config(FuyuConfig)
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(FuyuProcessor) return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)
def get_image_processor(self) -> FuyuImageProcessor: def get_image_processor(self) -> FuyuImageProcessor:
return self.get_hf_processor().image_processor return self.get_hf_processor().image_processor
...@@ -104,6 +104,8 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -104,6 +104,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
image_processor = self.get_image_processor() image_processor = self.get_image_processor()
target_width = image_processor.size["width"] target_width = image_processor.size["width"]
target_height = image_processor.size["height"] target_height = image_processor.size["height"]
patch_width = image_processor.patch_size["width"]
patch_height = image_processor.patch_size["height"]
if not (image_width <= target_width and image_height <= target_height): if not (image_width <= target_width and image_height <= target_height):
height_scale_factor = target_height / image_height height_scale_factor = target_height / image_height
...@@ -113,8 +115,8 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -113,8 +115,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
image_height = int(image_height * optimal_scale_factor) image_height = int(image_height * optimal_scale_factor)
image_width = int(image_width * optimal_scale_factor) image_width = int(image_width * optimal_scale_factor)
ncols = math.ceil(image_width / 30) ncols = math.ceil(image_width / patch_width)
nrows = math.ceil(image_height / 30) nrows = math.ceil(image_height / patch_height)
return ncols, nrows return ncols, nrows
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from typing import List, Literal, Mapping, Optional, TypedDict, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention import AttentionMetadata
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
MultiModalFieldConfig,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from .chatglm import ChatGLMBaseModel, ChatGLMModel
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import flatten_bn, merge_multimodal_embeddings
class GLMVImagePixelInputs(TypedDict):
    """Typed container for raw image pixels handed to the vision tower.

    Keys:
        type: Discriminator tag; always the literal ``"pixel_values"``.
        data: Image tensor, shape
            ``(batch_size, num_channels, height, width)``.
    """

    type: Literal["pixel_values"]
    # Shape: (batch_size, num_channels, height, width)
    data: torch.Tensor
class EVA2CLIPPatchEmbedding(nn.Module):
    """Convert an image batch into a sequence of patch embeddings.

    A strided convolution projects non-overlapping patches to
    ``hidden_size`` channels, a learned CLS vector is prepended, and the
    learned position table is added to the whole sequence.
    """

    def __init__(self, config):
        """Build the projection, CLS vector and position table.

        Args:
            config: Object exposing ``in_channels``, ``hidden_size``,
                ``patch_size`` and ``num_positions`` attributes.
        """
        super().__init__()
        width = config.hidden_size
        patch = config.patch_size

        # kernel_size == stride, so each patch is projected exactly once.
        self.proj = nn.Conv2d(config.in_channels,
                              width,
                              kernel_size=patch,
                              stride=patch)
        self.cls_embedding = nn.Parameter(torch.zeros(1, width))
        self.position_embedding = nn.Embedding(config.num_positions, width)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Embed ``images``.

        Parameters:
            images : torch.Tensor
                Input image tensor with shape (B, C, H, W)
        Returns:
            torch.Tensor
                Transformed tensor with shape (B, L, D)
        """
        kernel = self.proj.weight
        # Move the input onto the projection's device/dtype before the conv.
        patches = self.proj(
            images.to(device=kernel.device, dtype=kernel.dtype))

        # (B, D, H', W') -> (B, H' * W', D)
        tokens = patches.flatten(2).transpose(1, 2)

        # One CLS row per batch element, placed in front of the patches.
        cls_rows = self.cls_embedding.expand(tokens.shape[0], -1, -1)
        sequence = torch.cat((cls_rows, tokens), dim=1)

        # The position table is broadcast over the batch dimension.
        return sequence + self.position_embedding.weight.unsqueeze(0)
class EVA2CLIPAttention(nn.Module):
    """Self-attention block of the EVA2-CLIP vision encoder.

    The head count is divided by the tensor-parallel world size to get the
    number of heads handled on this rank.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """Create the fused QKV projection, attention op and output layers.

        Args:
            config: Vision config exposing ``hidden_size``, ``num_heads``
                and ``dropout_prob``.
            quant_config: Optional quantization settings forwarded to the
                linear layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        hidden = config.hidden_size
        total_heads = config.num_heads

        self.hidden_size = hidden
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = total_heads // self.tp_size
        self.head_dim = hidden // total_heads
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            hidden,
            self.head_dim,
            total_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            hidden,
            hidden,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
                                       self.scale)
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention followed by the output projection and dropout."""
        packed, _ = self.query_key_value(x)  # B, L, 3 * H * D
        # The fused projection is split into equal q / k / v thirds.
        query, key, value = packed.chunk(3, dim=-1)

        context = self.attn(query, key, value)
        projected, _ = self.dense(context)
        return self.output_dropout(projected)
class EVA2CLIPMLP(nn.Module):
    """Two-layer feed-forward block of the EVA2-CLIP vision encoder."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """Create ``fc1`` (up-projection) and ``fc2`` (down-projection).

        Args:
            config: Vision config exposing ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            quant_config: Optional quantization settings forwarded to the
                linear layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        narrow = config.hidden_size
        wide = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            narrow,
            wide,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            wide,
            narrow,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(act(fc1(x)))``.

        Each linear layer returns a pair; only its first element is used.
        """
        expanded, _ = self.fc1(x)
        activated = self.activation_fn(expanded)
        contracted, _ = self.fc2(activated)
        return contracted
class EVA2CLIPTransformerLayer(nn.Module):
    """One encoder layer: attention and MLP, each with a residual add.

    Note that the layer norms are applied to the *output* of the attention
    and MLP sub-blocks (before the residual add), not to their input.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """Create both sub-blocks and their layer norms.

        Args:
            config: Vision config exposing ``hidden_size`` and
                ``layer_norm_eps`` plus whatever the sub-blocks read.
            quant_config: Optional quantization settings forwarded to the
                sub-blocks.
            prefix: Dotted module path used to name the sub-blocks.
        """
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps

        self.input_layernorm = LayerNorm(width, eps=eps)
        self.attention = EVA2CLIPAttention(config,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.attention")
        self.mlp = EVA2CLIPMLP(config,
                               quant_config=quant_config,
                               prefix=f"{prefix}.mlp")
        self.post_attention_layernorm = LayerNorm(width, eps=eps)

    def forward(self, hidden_states):
        """Run the attention and MLP sub-blocks with residual connections."""
        # Attention branch: normalise the branch output, then add residual.
        attn_branch = self.input_layernorm(self.attention(hidden_states))
        after_attention = hidden_states + attn_branch

        # MLP branch follows the same pattern.
        mlp_branch = self.post_attention_layernorm(self.mlp(after_attention))
        return after_attention + mlp_branch
class EVA2CLIPTransformer(nn.Module):
    """Stack of ``config.num_hidden_layers`` EVA2-CLIP encoder layers."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """Instantiate the layers in index order.

        Args:
            config: Vision config; ``num_hidden_layers`` sets the depth.
            quant_config: Optional quantization settings forwarded to
                every layer.
            prefix: Dotted module path; layer ``i`` is named
                ``{prefix}.layers.{i}``.
        """
        super().__init__()
        stack = []
        for layer_idx in range(config.num_hidden_layers):
            stack.append(
                EVA2CLIPTransformerLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}"))
        self.layers = nn.ModuleList(stack)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in sequence."""
        output = hidden_states
        for block in self.layers:
            output = block(output)
        return output
class EVA2CLIPGLU(nn.Module):
    """Gated projection applied to the vision features.

    The original implementation keeps two separate projections:
    ```python
    self.dense_h_to_4h = ColumnParallelLinear(
        config.hidden_size,
        config.ffn_hidden_size,
        bias=False,
        quant_config=quant_config
    )
    self.gate_proj = ColumnParallelLinear(
        config.hidden_size,
        config.ffn_hidden_size,
        bias=False,
        quant_config=quant_config
    )
    ```
    and concatenates their outputs:
    ```
    gate_proj_output, _ = self.gate_proj(x)
    dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
    x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
    ```
    Here the two ColumnParallelLinear layers are merged into a single
    MergedColumnParallelLinear:
    ```
    self.merged_proj = MergedColumnParallelLinear(
        config.hidden_size,
        [config.ffn_hidden_size] * 2,
        bias=False,
        quant_config=quant_config
    )
    ```
    so the forward pass needs only:
    ```
    x, _ = self.merged_proj(x)
    ```
    """

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """Create the input projection, norm, activations and gated MLP.

        Args:
            config: Config exposing ``hidden_size`` and ``ffn_hidden_size``.
            in_features: Width of the incoming features.
            quant_config: Optional quantization settings forwarded to the
                linear layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        width = config.hidden_size
        ffn_width = config.ffn_hidden_size

        self.linear_proj = ReplicatedLinear(in_features,
                                            width,
                                            bias=False,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.linear_proj")
        self.norm1 = nn.LayerNorm(width)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        # Gate and up projections fused into a single layer (see class doc).
        self.merged_proj = MergedColumnParallelLinear(
            width, [ffn_width] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj")

        self.dense_4h_to_h = RowParallelLinear(
            ffn_width,
            width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h")

    def forward(self, x):
        """Project, normalise, gate and project back to ``hidden_size``."""
        projected, _ = self.linear_proj(x)
        normed = self.act1(self.norm1(projected))
        fused, _ = self.merged_proj(normed)
        gated = self.act2(fused)
        result, _ = self.dense_4h_to_h(gated)
        return result
class EVA2CLIPModel(nn.Module):
    """Vision tower: patch embedding, transformer, downsample, projection.

    The output sequence is wrapped in learned begin-of-image (``boi``) and
    end-of-image (``eoi``) embeddings and divided by ``scaling_factor``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """Build the vision tower.

        Args:
            config: Model config. ``config.vision_config`` is a mapping
                that is wrapped in a ``Namespace`` for attribute access;
                ``config.hidden_size`` sets the output width.
            quant_config: Optional quantization settings forwarded to the
                sub-modules that accept them.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        vision_config = Namespace(**config.vision_config)
        out_width = config.hidden_size

        self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config)
        self.transformer = EVA2CLIPTransformer(vision_config,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.transformer")
        self.linear_proj = EVA2CLIPGLU(config,
                                       in_features=out_width,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.linear_proj")
        # 2x2 / stride-2 conv halves each side of the patch grid.
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=out_width,
                              kernel_size=2,
                              stride=2)
        self.boi = nn.Parameter(torch.zeros(1, 1, out_width))
        self.eoi = nn.Parameter(torch.zeros(1, 1, out_width))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` into a sequence of embeddings.

        Parameters:
            images : torch.Tensor
                Input image tensor with shape (B, C, H, W)
        Returns:
            torch.Tensor
                Transformed tensor with shape (B, L, D)
        """
        encoded = self.transformer(self.patch_embedding(images))

        # Drop the leading CLS token; only patch tokens are kept.
        patches = encoded[:, 1:]
        batch, seq_len, width = patches.shape

        # Restore the 2-D patch grid (the grid is treated as square).
        side = int(seq_len**0.5)
        grid = patches.view(batch, side, side, width).permute(0, 3, 1, 2)

        # Downsample, then flatten back to a token sequence.
        pooled = self.conv(grid).flatten(2).transpose(1, 2)
        features = self.linear_proj(pooled)

        num_rows = features.shape[0]
        begin = self.boi.expand(num_rows, -1, -1)
        end = self.eoi.expand(num_rows, -1, -1)
        wrapped = torch.cat((begin, features, end), dim=1)
        return wrapped / self.scaling_factor
class GLM4VModel(ChatGLMModel):
    """ChatGLM backbone with an EVA2-CLIP vision tower attached."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialise the backbone, then add the ``vision`` sub-module.

        Args:
            vllm_config: Engine configuration; its ``quant_config`` is
                passed on to the vision tower.
            prefix: Dotted module path; the tower is named
                ``{prefix}.vision``.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # self.config is expected to be set by the parent initialiser.
        self.vision = EVA2CLIPModel(
            self.config,
            vllm_config.quant_config,
            prefix=f"{prefix}.vision",
        )
class GLM4VProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The processor tokenizes text and turns images into normalised tensors
    resized to the square ``image_size`` taken from the vision config.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        """Store the config/tokenizer and build the image transform.

        Args:
            config: Model config; ``config.vision_config["image_size"]``
                sets the resize target.
            tokenizer: Tokenizer applied to the text inputs.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        side = config.vision_config["image_size"]
        resize = transforms.Resize(
            (side, side),
            interpolation=InterpolationMode.BICUBIC,
        )
        normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        self.image_transform = transforms.Compose([
            resize,
            transforms.ToTensor(),
            normalize,
        ])

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images``.

        Args:
            text: A single prompt, a list of prompts, or ``None``.
            images: A single image, a list of images, or ``None``.
            return_tensors: Forwarded to ``BatchFeature`` as
                ``tensor_type``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs and, when at
            least one image is given, a stacked ``pixel_values`` tensor.
        """
        # Normalise both arguments to lists.
        if text is None:
            prompts = []
        elif isinstance(text, list):
            prompts = text
        else:
            prompts = [text]

        if images is None:
            pictures = []
        elif isinstance(images, list):
            pictures = images
        else:
            pictures = [images]

        features = {**self.tokenizer(prompts)}
        if pictures:
            transformed = [self.image_transform(pic) for pic in pictures]
            features["pixel_values"] = torch.stack(transformed)

        return BatchFeature(features, tensor_type=return_tensors)
class GLM4VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for GLM-4V: config, processor and token counts."""

    def get_hf_config(self):
        """Return the model config, requested as a ``ChatGLMConfig``."""
        return self.ctx.get_hf_config(ChatGLMConfig)

    def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
        """Build a ``GLM4VProcessor`` through the processing context.

        Extra ``kwargs`` are forwarded to ``init_processor`` unchanged.
        """
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return self.ctx.init_processor(
            GLM4VProcessor,
            config=hf_config,
            tokenizer=tokenizer,
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report that one image per prompt is supported."""
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token budget.

        ``seq_len`` and ``mm_counts`` are not used here; the count is
        derived from the vision config alone.
        """
        tokens_per_image = self.get_num_image_feature_tokens()
        return {"image": tokens_per_image}

    def get_num_image_tokens(self) -> int:
        """Return the number of placeholder tokens for one image.

        The patch grid side (``image_size // patch_size``) is halved and
        the result squared.
        """
        vision_config = self.get_hf_config().vision_config
        patches_per_side = (vision_config["image_size"] //
                            vision_config["patch_size"])
        grid_length = patches_per_side // 2
        return grid_length**2

    def get_num_image_feature_tokens(self) -> int:
        """Return the image token count including the boi/eoi embeddings."""
        # EVA2CLIPModel has embeddings for boi and eoi tokens as well
        return 2 + self.get_num_image_tokens()
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
    """Builds dummy prompts and images for GLM-4V."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return dummy inputs holding ``mm_counts["image"]`` images.

        Each dummy image is square with side ``image_size`` from the
        vision config, and the prompt repeats one image placeholder
        triple per image. ``seq_len`` is not used.
        """
        side = self.info.get_hf_config().vision_config["image_size"]
        num_images = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=num_images)

        base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
        prompt = base_text * num_images

        return ProcessorInputs(
            prompt_text=prompt,
            mm_data={"image": dummy_images},
        )
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
    """Multimodal processor that expands image placeholders for GLM-4V."""

    def _hf_processor_applies_repl(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """Always ``False``, regardless of the arguments."""
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a batched field of the image modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return the rule that expands one image placeholder triple.

        The target ``[boi, pad, eoi]`` is replaced by ``boi``, then
        ``get_num_image_tokens()`` copies of the pad token, then ``eoi``.
        """
        hf_config = self.info.get_hf_config()
        boi_token_id = hf_config.boi_token_id
        image_token_id = hf_config.pad_token_id
        eoi_token_id = hf_config.eoi_token_id

        def get_replacement(item_idx: int):
            # Every image expands to the same length; item_idx is unused.
            repeat = self.info.get_num_image_tokens()
            return [boi_token_id, *([image_token_id] * repeat), eoi_token_id]

        image_rule = PromptReplacement(
            modality="image",
            target=[boi_token_id, image_token_id, eoi_token_id],
            replacement=get_replacement,
        )
        return [image_rule]
@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                        info=GLM4VProcessingInfo,
                                        dummy_inputs=GLM4VDummyInputsBuilder)
class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                       SupportsMultiModal):
    """GLM-4V causal language model with image input support."""

    # Fused module name -> names of the modules packed into it.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        "merged_proj": ["gate_proj", "dense_h_to_4h"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        # vision
        "fc1",
        "fc2",
        "merged_proj",
        "linear_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer",
        )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[GLM4VModel] = GLM4VModel,
    ) -> None:
        """Delegate construction to the base model.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path for weight naming.
            transformer_type: Backbone class to instantiate; defaults to
                ``GLM4VModel``.
        """
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            transformer_type=transformer_type,
        )

        # Annotation only: narrows the attribute type for type checkers.
        self.transformer: GLM4VModel

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has per-image shape ``(3, size, size)``.

        Returns ``data`` unchanged on success.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        side = self.config.vision_config["image_size"]
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[GLMVImagePixelInputs]:
        """Extract and validate ``pixel_values`` from ``kwargs``.

        Returns ``None`` when no pixel values were supplied.

        Raises:
            ValueError: If ``pixel_values`` is not a tensor or has an
                unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        flattened = flatten_bn(pixel_values, concat=True)
        return GLMVImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(flattened),
        )

    def _process_image_input(
            self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
        """Cast the pixels to the model dtype and run the vision tower."""
        model_dtype = self.config.torch_dtype
        pixel_values = image_input["data"].to(dtype=model_dtype)
        return self.transformer.vision(pixel_values)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return vision embeddings, or ``None`` when no image is supplied."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in any multimodal embeddings.

        The boi, pad and eoi token ids are all passed as placeholder ids
        to ``merge_multimodal_embeddings``.
        """
        inputs_embeds = self.transformer.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        placeholder_ids = [
            self.config.boi_token_id,
            self.config.pad_token_id,
            self.config.eoi_token_id,
        ]
        return merge_multimodal_embeddings(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            placeholder_token_id=placeholder_ids,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its hidden states.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is
        discarded. Otherwise, if ``inputs_embeds`` is missing, it is
        computed here from ``input_ids`` plus any image inputs in
        ``kwargs``, and ``input_ids`` is then cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            image_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, image_embeds)
            input_ids = None

        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                inputs_embeds)
...@@ -15,9 +15,9 @@ from vllm.model_executor.layers.pooler import PoolerHead ...@@ -15,9 +15,9 @@ from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata, from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors) PoolingTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput) PoolingSequenceGroupOutput)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -29,12 +29,7 @@ class GritLMPooler(nn.Module): ...@@ -29,12 +29,7 @@ class GritLMPooler(nn.Module):
self.model_config = model_config self.model_config = model_config
tokenizer = cached_get_tokenizer( tokenizer = cached_tokenizer_from_config(self.model_config)
self.model_config.tokenizer,
tokenizer_mode=self.model_config.tokenizer_mode,
tokenizer_revision=self.model_config.tokenizer_revision,
trust_remote_code=self.model_config.trust_remote_code,
)
# Collect the tokens needed for pattern matching. # Collect the tokens needed for pattern matching.
# "▁<" is different from "_<". The former uses "▁" to indicate that # "▁<" is different from "_<". The former uses "▁" to indicate that
......
...@@ -41,6 +41,7 @@ def resolve_h2ovl_min_max_num( ...@@ -41,6 +41,7 @@ def resolve_h2ovl_min_max_num(
dynamic_image_size: bool, dynamic_image_size: bool,
use_thumbnail: bool, use_thumbnail: bool,
) -> tuple[int, int]: ) -> tuple[int, int]:
min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1: if use_thumbnail and max_dynamic_patch != 1:
...@@ -190,7 +191,7 @@ def image_to_pixel_values_h2ovl( ...@@ -190,7 +191,7 @@ def image_to_pixel_values_h2ovl(
pixel_values1, aspect_ratio1 = _preprocess_image( pixel_values1, aspect_ratio1 = _preprocess_image(
image, image,
input_size=input_size, input_size=input_size,
min_num=min_num, min_num=1,
max_num=max_num, max_num=max_num,
use_thumbnail=True, use_thumbnail=True,
prior_aspect_ratio=None, prior_aspect_ratio=None,
...@@ -199,7 +200,7 @@ def image_to_pixel_values_h2ovl( ...@@ -199,7 +200,7 @@ def image_to_pixel_values_h2ovl(
pixel_values2, _ = _preprocess_image( pixel_values2, _ = _preprocess_image(
image, image,
input_size=input_size, input_size=input_size,
min_num=3, # Hardcoded value min_num=3,
max_num=max_num, max_num=max_num,
use_thumbnail=True, use_thumbnail=True,
prior_aspect_ratio=aspect_ratio1, prior_aspect_ratio=aspect_ratio1,
...@@ -228,6 +229,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -228,6 +229,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
use_msac: Optional[bool] = None, use_msac: Optional[bool] = None,
...@@ -235,6 +237,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -235,6 +237,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
super().__init__( super().__init__(
config, config,
tokenizer, tokenizer,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
) )
...@@ -267,11 +270,13 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -267,11 +270,13 @@ class H2OVLProcessor(BaseInternVLProcessor):
def resolve_min_max_num( def resolve_min_max_num(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None, use_thumbnail: Optional[bool] = None,
) -> tuple[int, int]: ) -> tuple[int, int]:
min_dynamic_patch = self.min_dynamic_patch min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
is None else min_dynamic_patch)
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
is None else max_dynamic_patch) is None else max_dynamic_patch)
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
...@@ -289,18 +294,21 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -289,18 +294,21 @@ class H2OVLProcessor(BaseInternVLProcessor):
def resolve_target_ratios( def resolve_target_ratios(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None, use_thumbnail: Optional[bool] = None,
prior_aspect_ratio: Optional[tuple[int, int]] = None, prior_aspect_ratio: Optional[tuple[int, int]] = None,
override_min_num: Optional[int] = None,
) -> list[tuple[int, int]]: ) -> list[tuple[int, int]]:
min_num, max_num = self.resolve_min_max_num( min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail, use_thumbnail=use_thumbnail,
) )
if prior_aspect_ratio: # hardcoded value for second pass of use_msac if override_min_num is not None:
min_num = 3 min_num = override_min_num
return get_h2ovl_target_ratios( return get_h2ovl_target_ratios(
min_num, min_num,
...@@ -322,6 +330,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -322,6 +330,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
if use_msac: if use_msac:
target_ratios_1 = self.resolve_target_ratios( target_ratios_1 = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets use_thumbnail=False, # Applied in calculate_targets
override_min_num=1,
) )
num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets( num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
orig_width=image_width, orig_width=image_width,
...@@ -334,6 +343,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -334,6 +343,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
target_ratios_2 = self.resolve_target_ratios( target_ratios_2 = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets use_thumbnail=False, # Applied in calculate_targets
prior_aspect_ratio=aspect_ratio_1, prior_aspect_ratio=aspect_ratio_1,
override_min_num=3,
) )
num_patches_2, _, _, _ = calculate_h2ovl_targets( num_patches_2, _, _, _ = calculate_h2ovl_targets(
orig_width=image_width, orig_width=image_width,
...@@ -361,12 +371,14 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -361,12 +371,14 @@ class H2OVLProcessor(BaseInternVLProcessor):
def _images_to_pixel_values_lst( def _images_to_pixel_values_lst(
self, self,
images: list[Image.Image], images: list[Image.Image],
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
use_msac = self.use_msac if len(images) == 1 else False use_msac = self.use_msac if len(images) == 1 else False
min_num, max_num = self.resolve_min_max_num( min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
use_thumbnail=False, # Applied in image_to_pixel_values use_thumbnail=False, # Applied in image_to_pixel_values
...@@ -389,14 +401,23 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -389,14 +401,23 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor( def get_hf_processor(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> H2OVLProcessor: ) -> H2OVLProcessor:
return H2OVLProcessor( if min_dynamic_patch is not None:
self.get_hf_config(), kwargs["min_dynamic_patch"] = min_dynamic_patch
self.get_tokenizer(), if max_dynamic_patch is not None:
max_dynamic_patch=max_dynamic_patch, kwargs["max_dynamic_patch"] = max_dynamic_patch
dynamic_image_size=dynamic_image_size, if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
H2OVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
) )
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
......
...@@ -83,13 +83,15 @@ ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] ...@@ -83,13 +83,15 @@ ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
class Idefics3ProcessingInfo(BaseProcessingInfo): class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_hf_processor( def get_hf_processor(
self, self,
*, *,
size: Optional[Dict[str, int]] = None) -> Idefics3Processor: size: Optional[Dict[str, int]] = None,
**kwargs: object,
) -> Idefics3Processor:
if size is not None: if size is not None:
return self.ctx.get_hf_processor(Idefics3Processor, size=size) kwargs["size"] = size
return self.ctx.get_hf_processor(Idefics3Processor) return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import supports_kw from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model from .interfaces_base import is_pooling_model
...@@ -445,3 +447,60 @@ def supports_cross_encoding( ...@@ -445,3 +447,60 @@ def supports_cross_encoding(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return is_pooling_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
class SupportsQuant:
    """Interface that every model supporting quantization must implement.

    The quantization config is looked up among the constructor arguments
    while the object is being created, stored on the new instance, and the
    class-level ``packed_modules_mapping`` is merged into that config.
    """

    # Class-level mapping that gets merged into the located config's mapping.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
    # Remains ``None`` unless a config is found among the constructor args.
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> "SupportsQuant":
        """Create the instance and attach the quantization config, if any."""
        self = super().__new__(cls)
        found = cls._find_quant_config(*args, **kwargs)
        if found is None:
            # Nothing to attach; the class-level default (None) stays visible.
            return self

        self.quant_config = found
        self.quant_config.packed_modules_mapping.update(
            cls.packed_modules_mapping)
        return self

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Scan positional, then keyword, argument values for a config.

        Returns:
            The ``quant_config`` attribute of the first ``VllmConfig``
            encountered (which may itself be ``None``), or the first
            ``QuantizationConfig`` encountered, whichever comes first;
            ``None`` when no argument matches.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, VllmConfig):
                return candidate.quant_config
            if isinstance(candidate, QuantizationConfig):
                return candidate
        return None
@runtime_checkable
class SupportsTranscription(Protocol):
    """Protocol that every model supporting transcription must satisfy."""

    # Marker flag; because the protocol is runtime-checkable, ``isinstance``
    # succeeds for any object exposing an attribute with this name.
    supports_transcription: ClassVar[Literal[True]] = True
@overload
def supports_transcription(
        model: Type[object]) -> TypeIs[Type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Tell whether ``model`` satisfies the ``SupportsTranscription`` protocol.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` when ``model`` passes the runtime protocol check,
        ``False`` otherwise.
    """
    # The previous implementation branched on ``isinstance(model, type)`` but
    # both branches returned this exact expression, so one check is enough:
    # the protocol's marker is a class-level attribute and is therefore
    # visible on classes and on their instances alike.
    return isinstance(model, SupportsTranscription)
...@@ -120,6 +120,7 @@ def resolve_internvl_min_max_num( ...@@ -120,6 +120,7 @@ def resolve_internvl_min_max_num(
dynamic_image_size: bool, dynamic_image_size: bool,
use_thumbnail: bool, use_thumbnail: bool,
) -> tuple[int, int]: ) -> tuple[int, int]:
min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1: if use_thumbnail and max_dynamic_patch != 1:
...@@ -247,6 +248,7 @@ class BaseInternVLProcessor(ABC): ...@@ -247,6 +248,7 @@ class BaseInternVLProcessor(ABC):
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
) -> None: ) -> None:
...@@ -258,18 +260,22 @@ class BaseInternVLProcessor(ABC): ...@@ -258,18 +260,22 @@ class BaseInternVLProcessor(ABC):
image_size: int = config.vision_config.image_size image_size: int = config.vision_config.image_size
patch_size: int = config.vision_config.patch_size patch_size: int = config.vision_config.patch_size
if dynamic_image_size is None: if min_dynamic_patch is None:
dynamic_image_size = config.dynamic_image_size min_dynamic_patch = config.min_dynamic_patch
assert isinstance(dynamic_image_size, bool) assert isinstance(min_dynamic_patch, int)
if max_dynamic_patch is None: if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch max_dynamic_patch = config.max_dynamic_patch
assert isinstance(max_dynamic_patch, int) assert isinstance(max_dynamic_patch, int)
if dynamic_image_size is None:
dynamic_image_size = config.dynamic_image_size
assert isinstance(dynamic_image_size, bool)
self.num_image_token = int( self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2)) (image_size // patch_size)**2 * (config.downsample_ratio**2))
self.image_size = image_size self.image_size = image_size
self.min_dynamic_patch: int = config.min_dynamic_patch self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch self.max_dynamic_patch = max_dynamic_patch
self.dynamic_image_size = dynamic_image_size self.dynamic_image_size = dynamic_image_size
self.use_thumbnail: bool = config.use_thumbnail self.use_thumbnail: bool = config.use_thumbnail
...@@ -298,11 +304,13 @@ class BaseInternVLProcessor(ABC): ...@@ -298,11 +304,13 @@ class BaseInternVLProcessor(ABC):
def resolve_min_max_num( def resolve_min_max_num(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None, use_thumbnail: Optional[bool] = None,
) -> tuple[int, int]: ) -> tuple[int, int]:
min_dynamic_patch = self.min_dynamic_patch min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
is None else min_dynamic_patch)
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
is None else max_dynamic_patch) is None else max_dynamic_patch)
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
...@@ -320,11 +328,13 @@ class BaseInternVLProcessor(ABC): ...@@ -320,11 +328,13 @@ class BaseInternVLProcessor(ABC):
def resolve_target_ratios( def resolve_target_ratios(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None, use_thumbnail: Optional[bool] = None,
) -> list[tuple[int, int]]: ) -> list[tuple[int, int]]:
min_num, max_num = self.resolve_min_max_num( min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail, use_thumbnail=use_thumbnail,
...@@ -355,10 +365,12 @@ class BaseInternVLProcessor(ABC): ...@@ -355,10 +365,12 @@ class BaseInternVLProcessor(ABC):
def _images_to_pixel_values_lst( def _images_to_pixel_values_lst(
self, self,
images: list[Image.Image], images: list[Image.Image],
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
min_num, max_num = self.resolve_min_max_num( min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
use_thumbnail=False, # Applied in image_to_pixel_values use_thumbnail=False, # Applied in image_to_pixel_values
...@@ -378,6 +390,7 @@ class BaseInternVLProcessor(ABC): ...@@ -378,6 +390,7 @@ class BaseInternVLProcessor(ABC):
self, self,
text: Optional[Union[str, list[str]]] = None, text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None, images: Optional[Union[Image.Image, list[Image.Image]]] = None,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
...@@ -396,6 +409,7 @@ class BaseInternVLProcessor(ABC): ...@@ -396,6 +409,7 @@ class BaseInternVLProcessor(ABC):
else: else:
pixel_values_lst = self._images_to_pixel_values_lst( pixel_values_lst = self._images_to_pixel_values_lst(
images, images,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
) )
...@@ -451,8 +465,10 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ...@@ -451,8 +465,10 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
def get_hf_processor( def get_hf_processor(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseInternVLProcessor: ) -> BaseInternVLProcessor:
raise NotImplementedError raise NotImplementedError
...@@ -642,14 +658,23 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -642,14 +658,23 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor( def get_hf_processor(
self, self,
*, *,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> InternVLProcessor: ) -> InternVLProcessor:
return InternVLProcessor( if min_dynamic_patch is not None:
self.get_hf_config(), kwargs["min_dynamic_patch"] = min_dynamic_patch
self.get_tokenizer(), if max_dynamic_patch is not None:
max_dynamic_patch=max_dynamic_patch, kwargs["max_dynamic_patch"] = max_dynamic_patch
dynamic_image_size=dynamic_image_size, if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
InternVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
) )
......
...@@ -426,17 +426,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -426,17 +426,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -453,16 +442,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -453,16 +442,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_mamba_layers = self.model_config.get_num_layers_by_block_type( num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba) self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape()) *self._get_mamba_cache_shape())
(
mamba_cache_tensors, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params, attn_metadata, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
......
...@@ -134,6 +134,9 @@ class LlamaAttention(nn.Module): ...@@ -134,6 +134,9 @@ class LlamaAttention(nn.Module):
# MistralConfig has an optional head_dim introduced by Mistral-Nemo # MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim", self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads) self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
...@@ -165,7 +168,7 @@ class LlamaAttention(nn.Module): ...@@ -165,7 +168,7 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
...@@ -622,6 +625,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -622,6 +625,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
mistral_mapping = { mistral_mapping = {
"layers": "model.layers", "layers": "model.layers",
"attention": "self_attn", "attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"wq": "q_proj", "wq": "q_proj",
"wk": "k_proj", "wk": "k_proj",
"wv": "v_proj", "wv": "v_proj",
...@@ -750,15 +756,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -750,15 +756,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
modules = name.split(".") modules = name.split(".")
# rotary embeds should be sliced # rotary embeds should be sliced
if "wk" in modules: if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight, loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads) self.config.num_key_value_heads)
elif "wq" in modules: elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight, loaded_weight = permute(loaded_weight,
self.config.num_attention_heads) self.config.num_attention_heads)
for item in modules: num_modules = len(modules)
if item in mapping and mapping[item] not in name: for i in range(num_modules):
item = modules[i]
next_item = modules[i + 1] if i < num_modules - 1 else None
combined_item = (f"{item}.{next_item}"
if next_item is not None else None)
if combined_item in mapping:
name = name.replace(combined_item, mapping[combined_item])
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item]) name = name.replace(item, mapping[item])
return name, loaded_weight return name, loaded_weight
...@@ -119,7 +119,7 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -119,7 +119,7 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
return get_vision_encoder_info(self.get_hf_config()) return get_vision_encoder_info(self.get_hf_config())
@abstractmethod @abstractmethod
def get_hf_processor(self) -> LlavaLikeProcessor: def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
raise NotImplementedError raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
...@@ -208,8 +208,8 @@ class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -208,8 +208,8 @@ class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class LlavaProcessingInfo(BaseLlavaProcessingInfo): class LlavaProcessingInfo(BaseLlavaProcessingInfo):
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(LlavaProcessor) return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
...@@ -272,8 +272,8 @@ class LlavaMultiModalProcessor( ...@@ -272,8 +272,8 @@ class LlavaMultiModalProcessor(
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(PixtralProcessor) return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
class PixtralHFMultiModalProcessor( class PixtralHFMultiModalProcessor(
...@@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor( ...@@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor(
pixel_values = processed_outputs.get("pixel_values") pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None: if pixel_values is not None:
# Before/after https://github.com/huggingface/transformers/pull/35122 # Before/after https://github.com/huggingface/transformers/pull/35122
if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"): if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
images = mm_data["images"] images = mm_data["images"]
assert isinstance(images, list) assert isinstance(images, list)
...@@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: ...@@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""Given an signed vision feature layer, get the number of hidden layers """Given a signed vision feature layer, get the number of hidden layers
needed to leverage it. needed to leverage it.
Args: Args:
...@@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: ...@@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
""" """
if feature_layer_index < 0: if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1 return num_hidden_layers + feature_layer_index + 1
return feature_layer_index + 1 return feature_layer_index
def init_vision_tower_for_llava( def init_vision_tower_for_llava(
...@@ -742,23 +742,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -742,23 +742,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class MantisProcessingInfo(LlavaProcessingInfo): class MantisProcessingInfo(LlavaProcessingInfo):
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_info = self.get_vision_encoder_info() vision_info = self.get_vision_encoder_info()
kwargs.setdefault("patch_size", vision_info.get_patch_size())
if Version(TRANSFORMERS_VERSION) < Version("4.48"): if Version(TRANSFORMERS_VERSION) < Version("4.48"):
# BUG: num_additional_image_tokens = 0 but treated as 1, # BUG: num_additional_image_tokens = 0 but treated as 1,
# so we set vision_feature_select_strategy to None to offset this # so we set vision_feature_select_strategy to None to offset this
vision_feature_select_strategy = None kwargs.setdefault("vision_feature_select_strategy", None)
else: else:
# FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150 # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501 kwargs.setdefault(
"vision_feature_select_strategy",
hf_config.vision_feature_select_strategy,
)
return self.ctx.get_hf_processor( return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
LlavaProcessor,
patch_size=vision_info.get_patch_size(),
vision_feature_select_strategy=vision_feature_select_strategy,
)
class MantisMultiModalProcessor(LlavaMultiModalProcessor): class MantisMultiModalProcessor(LlavaMultiModalProcessor):
...@@ -819,7 +820,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -819,7 +820,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt_ids, prompt_ids,
mm_item_counts, mm_item_counts,
) )
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = { mm_placeholder_ranges = {
......
...@@ -72,8 +72,8 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): ...@@ -72,8 +72,8 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
def get_hf_config(self) -> LlavaNextLikeConfig: def get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig) return self.ctx.get_hf_config(LlavaNextConfig)
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor) hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)
# In case patch_size is omitted from `processor_config.json` # In case patch_size is omitted from `processor_config.json`
# e.g. for E5-V: https://huggingface.co/royokong/e5-v # e.g. for E5-V: https://huggingface.co/royokong/e5-v
......
...@@ -56,8 +56,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -56,8 +56,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_vision_encoder_info(self): def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config()) return get_vision_encoder_info(self.get_hf_config())
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(LlavaNextVideoProcessor) return self.ctx.get_hf_processor(LlavaNextVideoProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1} return {"video": 1}
......
...@@ -97,8 +97,8 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -97,8 +97,8 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_hf_config(self) -> LlavaOnevisionLikeConfig: def get_hf_config(self) -> LlavaOnevisionLikeConfig:
return self.ctx.get_hf_config(LlavaOnevisionConfig) return self.ctx.get_hf_config(LlavaOnevisionConfig)
def get_hf_processor(self): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(LlavaOnevisionProcessor) return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
...@@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor( ...@@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor(
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
# So, we process each component separately
# NOTE: No prompt replacement is applied in this case
processor = self.info.get_hf_processor() processor = self.info.get_hf_processor()
image_token = processor.image_token
video_token = processor.video_token video_token = processor.video_token
# LLaVA-OneVision processor doesn't support multiple videos text_outputs = super()._call_hf_processor(
# with different sizes when converting back to tensors
text_image_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data={},
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
images = mm_data.pop("images", [])
assert isinstance(images, list)
if images:
processor_outputs = super()._call_hf_processor(
prompt=image_token * len(images),
mm_data={"images": images},
mm_kwargs=mm_kwargs,
)
image_outputs = {
k: v
for k, v in processor_outputs.items()
if k in ("pixel_values", "image_sizes")
}
else:
image_outputs = {}
pixel_values_videos = [] pixel_values_videos = []
for video in videos: for video in videos:
item_processor_data = dict(prompt=video_token, videos=video)
item_outputs = super()._call_hf_processor( item_outputs = super()._call_hf_processor(
prompt=prompt, prompt=video_token,
mm_data=item_processor_data, mm_data={"videos": video},
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
pixel_values_videos.append( pixel_values_videos.append(item_outputs["pixel_values_videos"][0])
item_outputs.pop("pixel_values_videos")[0])
video_outputs = {"pixel_values_videos": pixel_values_videos}
combined_outputs = dict( combined_outputs = dict(
**text_image_outputs, text_outputs,
pixel_values_videos=pixel_values_videos, **image_outputs,
**video_outputs,
) )
return BatchFeature(combined_outputs) return BatchFeature(combined_outputs)
def _hf_processor_applies_repl(
    self,
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
    """Report whether the HF processor applies the prompt replacements.

    The parent implementation is consulted first; its falsy answer is
    returned unchanged. Otherwise the answer is true only when the request
    contains no video items.
    """
    parent_verdict = super()._hf_processor_applies_repl(
        prompt_text=prompt_text,
        mm_items=mm_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
    )
    if not parent_verdict:
        return parent_verdict

    # ``strict=False`` is forwarded as in the original call; any video item
    # means the answer is false.
    video_count = mm_items.get_count("video", strict=False)
    return video_count == 0
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -166,14 +166,13 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -166,14 +166,13 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching" "Mamba does not support prefix caching"
super().__init__() super().__init__()
self.config = config self.config = config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.backbone = MambaModel(vllm_config=vllm_config, self.backbone = MambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "backbone")) prefix=maybe_prefix(prefix, "backbone"))
...@@ -202,17 +201,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -202,17 +201,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors) self.backbone.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids) return self.backbone.get_input_embeddings(input_ids)
...@@ -229,18 +217,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -229,18 +217,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
num_mamba_layers = self.model_config.get_num_layers_by_block_type( num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba) self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape()) *self._get_mamba_cache_shape())
( mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors, mamba_cache_params, intermediate_tensors,
......
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model."""
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Mamba2DecoderLayer(nn.Module):
    """A single Mamba2 block: an RMSNorm followed by a ``MambaMixer2``.

    The layer threads two tensors through the stack, ``hidden_states`` and
    ``residual``; both are returned so the next layer can consume them.
    """

    def __init__(self,
                 config: MambaConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the mixer and its input norm from ``config``.

        Args:
            config: Model config providing the mixer and norm settings.
            quant_config: Optional quantization config handed to the mixer.
        """
        super().__init__()
        self.config = config

        mixer_kwargs = dict(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            # Configs without an explicit ``intermediate_size`` fall back to
            # ``expand * hidden_size``.
            intermediate_size=getattr(config, "intermediate_size",
                                      config.expand * config.hidden_size),
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=config.num_heads,
            head_dim=config.head_dim,
            rms_norm_eps=config.layer_norm_epsilon,
            activation=config.hidden_act,
            chunk_size=config.chunk_size,
            quant_config=quant_config,
        )
        self.mixer = MambaMixer2(**mixer_kwargs)

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        sequence_idx: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalise the input, run the mixer, and return the new state pair.

        Args:
            hidden_states: Input activations for this layer.
            attn_metadata: Metadata forwarded unchanged to the mixer.
            residual: Residual from the previous layer, or ``None`` for the
                first layer that runs.
            mamba_cache_params: Cache parameters forwarded to the mixer.
            sequence_idx: Optional sequence index tensor forwarded to the
                mixer.
            **kwargs: Accepted and ignored, so callers may pass extra keyword
                arguments (e.g. ``positions``).

        Returns:
            A ``(hidden_states, residual)`` tuple.
        """
        if residual is not None:
            # The norm receives both tensors and hands back the updated pair.
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            # No residual yet: the raw input becomes the residual and only
            # the normalised copy goes to the mixer.
            residual, hidden_states = hidden_states, self.norm(hidden_states)

        mixed = self.mixer(hidden_states, attn_metadata, mamba_cache_params,
                           sequence_idx)
        return mixed, residual
class Mamba2Model(nn.Module):
    """Mamba2 backbone: token embeddings, decoder layers and a final RMSNorm.

    The layer stack is built through ``make_layers``, which also supplies the
    ``start_layer`` / ``end_layer`` bounds stored on the instance.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embeddings, the decoder layers and the final norm.

        Args:
            vllm_config: Engine configuration; the HF model config, the
                quantization config and the LoRA config are read from it.
            prefix: Name prefix used when creating the layer stack.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        # This model refuses to be built with LoRA enabled.
        assert not is_lora_enabled

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(config,
                                              quant_config=quant_config),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder layers over the current batch.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to every layer as a keyword argument.
            attn_metadata: Supplies ``num_prefills`` and ``query_start_loc``
                and is forwarded to every layer.
            mamba_cache_params: Cache parameters; each layer receives the
                slice returned by ``at_layer_idx``.
            intermediate_tensors: Required on non-first pipeline ranks;
                provides ``hidden_states`` and ``residual``.
            inputs_embeds: Optional pre-computed embeddings used instead of
                embedding ``input_ids``.

        Returns:
            The normalised hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Pass a sequence index tensor, which is required for proper
        # continuous batching computation including chunked prefill.
        # Each token in [query_start_loc[i], query_start_loc[i + 1]) is
        # tagged with i; tokens outside every such range stay 0.
        seq_idx = None
        if attn_metadata.num_prefills > 0:
            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
            for i, (srt, end) in enumerate(
                    zip(
                        attn_metadata.query_start_loc,
                        attn_metadata.query_start_loc[1:],
                    )):
                seq_idx[srt:end] = i
            seq_idx.unsqueeze_(0)

        # Walk only [start_layer, end_layer): the cache index below is taken
        # relative to start_layer, so visiting indices below start_layer
        # would produce negative cache indices. When start_layer is 0 and
        # end_layer equals len(self.layers) this is the same iteration as
        # range(len(self.layers)).
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer),
                sequence_idx=seq_idx)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    """Mamba2 backbone combined with an LM head, logits processor and sampler.

    The Mamba state cache is created on the first ``forward`` call and kept
    on the instance afterwards.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Engine configuration; model, cache, LoRA and
                scheduler configs are read from it.
            prefix: Name prefix under which the backbone is created.
        """
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.backbone = Mamba2Model(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)

        # Holds the Mamba cache manager between steps; created lazily on the
        # first forward() call.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone, creating the Mamba cache on first use.

        Args:
            input_ids: Token ids for the current batch.
            positions: Forwarded to the backbone.
            kv_caches: Accepted for interface compatibility; not read here.
            attn_metadata: Forwarded to the backbone.
            intermediate_tensors: Forwarded to the backbone.
            inputs_embeds: Optional pre-computed embeddings.
            **kwargs: Forwarded to the cache manager's
                ``current_run_tensors``.

        Returns:
            Whatever the backbone returns for this pipeline rank.
        """
        if self.mamba_cache is None:
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                      mamba_cache_params, intermediate_tensors,
                                      inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the cache manager's method of the same name."""
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the cache manager's method of the same name."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int, int]]:
        """Compute the per-rank shapes of the conv and temporal states.

        Returns:
            ``(conv_state_shape, temporal_state_shape)`` where the conv shape
            is ``(conv_dim / world_size, conv_kernel - 1)`` and the temporal
            shape is ``(num_heads / world_size, head_dim, state_size)``.
        """
        world_size = get_tensor_model_parallel_world_size()

        intermediate_size = getattr(
            self.config, "intermediate_size",
            self.config.expand * self.config.hidden_size)

        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        n_groups = (
            self.config.n_groups +
            extra_groups_for_head_shards(self.config.n_groups, world_size))

        # - heads and n_groups are TP-ed
        conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.config.conv_kernel - 1,
        )

        # Only the head count is divided across TP ranks; head_dim and
        # state_size are kept whole.
        # e.g., (n_heads, d_head, d_state) = (128, 64, 128) before sharding
        temporal_state_shape = (
            divide(self.config.num_heads, world_size),
            self.config.head_dim,
            self.config.state_size,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name (after renaming) matches no parameter and is
                not skipped by one of the checks below.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint names containing "A_log" map to parameters named "A".
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List, Tuple
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
@dataclass @dataclass
...@@ -23,8 +23,14 @@ class MambaCacheParams: ...@@ -23,8 +23,14 @@ class MambaCacheParams:
class MambaCacheManager: class MambaCacheManager:
def __init__(self, dtype, num_mamba_layers, max_batch_size, def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
conv_state_shape, temporal_state_shape): num_mamba_layers: int, conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int]):
# Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape, conv_state_shape,
...@@ -42,8 +48,7 @@ class MambaCacheManager: ...@@ -42,8 +48,7 @@ class MambaCacheManager:
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size)) self.free_cache_indices = list(range(max_batch_size))
def current_run_tensors(self, input_ids: torch.Tensor, def current_run_tensors(self, **kwargs) -> MambaCacheParams:
attn_metadata: AttentionMetadata, **kwargs):
""" """
Return the tensors for the current run's conv and ssm state. Return the tensors for the current run's conv and ssm state.
""" """
...@@ -66,7 +71,8 @@ class MambaCacheManager: ...@@ -66,7 +71,8 @@ class MambaCacheManager:
(mamba_cache_tensors, (mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return (mamba_cache_tensors, state_indices_tensor) return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
""" """
......
...@@ -23,12 +23,12 @@ ...@@ -23,12 +23,12 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial from functools import partial
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Tuple, TypedDict, Union) Optional, Set, Tuple, TypedDict, Union)
import torch import torch
import torch.types
from torch import nn from torch import nn
from transformers import BatchFeature
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import ( from transformers.models.whisper.modeling_whisper import (
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
...@@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata ...@@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ModalityData, ModalityDataItems, from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData,
MultiModalDataItems, MultiModalDataParser, ModalityDataItems, MultiModalDataItems,
VideoItem) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import PromptReplacement
PromptReplacement)
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo) MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config)
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems
class MiniCPMOAudioFeatureInputs(TypedDict): class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
...@@ -103,28 +101,52 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, ...@@ -103,28 +101,52 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs] MiniCPMOAudioEmbeddingInputs]
class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems): def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
)
def __init__(self, data: Dict) -> None: class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
super().__init__(data, "audio")
audio_embeds = self.data.get("audio_embeds", None)
if audio_embeds is None:
raise ValueError("Incorrect type of video_embeds",
"Got type: None")
self.data["audio_embeds"] = audio_embeds
def get(self, index: int) -> object: def __init__(
return self.data["audio_embeds"][index] self,
data: Mapping[str, torch.Tensor],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
super().__init__(
data,
modality="image",
required_fields={"audio_embeds"},
fields_factory=fields_factory,
)
class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data( def _parse_audio_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems(data) return MiniCPMOAudioEmbeddingItems(
data,
fields_factory=_minicpmo_field_config,
)
return super()._parse_audio_data(data) return super()._parse_audio_data(data)
...@@ -167,6 +189,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -167,6 +189,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_max_audio_chunks_with_most_features(self) -> int: def get_max_audio_chunks_with_most_features(self) -> int:
return 30 return 30
def get_max_audio_tokens(self) -> int:
    """Return the per-chunk token maximum times the maximum chunk count."""
    tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
    max_chunks = self.get_max_audio_chunks_with_most_features()
    return tokens_per_chunk * max_chunks
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate() sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio> # exclude <audio> </audio>
...@@ -194,7 +220,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -194,7 +220,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
return num_frames return num_frames
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder): class MiniCPMODummyInputsBuilder(
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, seq_len: int, mm_counts: Mapping[str, self, seq_len: int, mm_counts: Mapping[str,
...@@ -222,8 +249,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder): ...@@ -222,8 +249,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
class MiniCPMOMultiModalProcessor( class MiniCPMOMultiModalProcessor(
MiniCPMVMultiModalProcessor, MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMOMultiModalDataParser( return MiniCPMOMultiModalDataParser(
...@@ -369,21 +395,10 @@ class MiniCPMOMultiModalProcessor( ...@@ -369,21 +395,10 @@ class MiniCPMOMultiModalProcessor(
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0)) return _minicpmo_field_config(hf_inputs)
return dict(
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices))
class MultiModalProjector(nn.Module): class MultiModalProjector(nn.Module):
...@@ -406,7 +421,7 @@ class MultiModalProjector(nn.Module): ...@@ -406,7 +421,7 @@ class MultiModalProjector(nn.Module):
class MiniCPMWhisperEncoderLayer(nn.Module): class MiniCPMWhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig, layer_idx: int = None): def __init__(self, config: WhisperConfig, layer_idx: int):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
self.self_attn = WHISPER_ATTENTION_CLASSES[ self.self_attn = WHISPER_ATTENTION_CLASSES[
......
...@@ -35,6 +35,7 @@ import torch.types ...@@ -35,6 +35,7 @@ import torch.types
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, PlaceholderRange) MultiModalInputs, PlaceholderRange)
from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData, from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
ModalityDataItems, MultiModalDataItems, ModalityData, ModalityDataItems,
MultiModalDataParser, VideoItem) MultiModalDataItems, MultiModalDataParser,
VideoItem)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...@@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): ...@@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs] MiniCPMVImageEmbeddingInputs]
class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
def __init__(self, data: Dict, modality: str) -> None:
super().__init__(data, modality)
def get_processor_data(self) -> Mapping[str, object]:
return self.data
def get_passthrough_data(self) -> Mapping[str, object]:
return {}
def get_count(self) -> int:
return len(self.data[f"{self.modality}_embeds"])
def get(self, index: int) -> Dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
out[k] = v[index]
return out
class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "image")
image_embeds = self.data.get("image_embeds", None)
image_sizes = self.data.get("image_sizes", None)
if image_embeds is None:
raise ValueError("In correct type of image_embeds",
"Got type: None")
if not isinstance(image_embeds[0], torch.Tensor):
raise ValueError("In correct type of image_embeds",
f"Got type: {type(image_embeds[0])}")
if image_sizes is None:
raise ValueError(
"In correct type of image_sizes", "Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`")
if len(image_embeds[0].shape) == 2:
image_embeds = [image_embeds]
image_sizes = [image_sizes]
self.data["image_embeds"] = image_embeds
self.data["image_sizes"] = image_sizes
def get_image_size(self, index: int) -> ImageSize:
image_size = self.data["image_sizes"][index]
return ImageSize(width=image_size[0], height=image_size[1])
class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "video")
video_embeds = self.data.get("video_embeds", None)
image_sizes = self.data.get("image_sizes", None)
num_frames = self.data.get("num_frames", None)
if video_embeds is None:
raise ValueError("In correct type of video_embeds",
"Got type: None")
if not isinstance(video_embeds[0], torch.Tensor):
raise ValueError("In correct type of video_embeds",
f"Got type: {type(video_embeds[0])}")
if image_sizes is None:
raise ValueError(
"In correct type of image_sizes", "Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`")
if num_frames is None:
raise ValueError("In correct type of numframes", "Got type: None")
if len(video_embeds[0].shape) == 2:
video_embeds = [video_embeds]
image_sizes = [image_sizes]
num_frames = [num_frames]
self.data["video_embeds"] = video_embeds
self.data["image_sizes"] = image_sizes
self.data["num_frames"] = num_frames
def get_frame_size(self, index: int) -> ImageSize:
frame_size = self.data["image_sizes"][index]
return ImageSize(width=frame_size[0], height=frame_size[1])
def get_num_frames(self, index: int) -> int:
return self.data["num_frames"][index]
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
...@@ -311,6 +226,77 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: ...@@ -311,6 +226,77 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
return tuple(int(x) for x in version_str.split(".")) return tuple(int(x) for x in version_str.split("."))
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Build the field-name to ``MultiModalFieldConfig`` mapping.

    Image fields are keyed to the ``"image"`` modality and video fields to
    ``"video"``. Flat fields are sized by the matching ``*_num_slices`` entry
    of ``hf_inputs``; an empty tensor is used when that entry is absent.

    Args:
        hf_inputs: Processor outputs, read only for ``image_num_slices`` and
            ``video_num_slices``.

    Returns:
        A dict mapping each field name to its field config.
    """
    flat_from_sizes = MultiModalFieldConfig.flat_from_sizes
    batched = MultiModalFieldConfig.batched

    image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
    video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))

    return {
        "pixel_values": flat_from_sizes("image", image_num_slices),
        "image_sizes": batched("image"),
        "tgt_sizes": flat_from_sizes("image", image_num_slices),
        "image_num_slices": batched("image"),
        "image_embeds": flat_from_sizes("image", image_num_slices),
        "video_pixel_values": flat_from_sizes("video", video_num_slices),
        "video_image_sizes": batched("video"),
        "video_tgt_sizes": flat_from_sizes("video", video_num_slices),
        "video_embeds": flat_from_sizes("video", video_num_slices),
        "video_num_slices": batched("video"),
    }
class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
    """Image embedding items passed in as a dict of tensors.

    The dict must provide ``image_embeds`` and ``image_sizes``.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """Register ``data`` under the ``"image"`` modality.

        Args:
            data: Mapping of field name to tensor.
            fields_factory: Callable producing the field configs for a
                mapping of tensors; forwarded to the base class.
        """
        required = {"image_embeds", "image_sizes"}
        super().__init__(
            data,
            modality="image",
            required_fields=required,
            fields_factory=fields_factory,
        )

    def get_image_size(self, index: int) -> ImageSize:
        """Return the size stored for item ``index``.

        Element 0 of the item's ``image_sizes`` entry is used as the width
        and element 1 as the height.
        """
        size_entry = self.get(index)["image_sizes"].tolist()
        return ImageSize(width=size_entry[0], height=size_entry[1])
class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
    """Video embedding items passed in as a dict of tensors.

    The dict must provide ``video_embeds`` and ``video_image_sizes``.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """Register ``data`` under the ``"video"`` modality.

        Args:
            data: Mapping of field name to tensor.
            fields_factory: Callable producing the field configs for a
                mapping of tensors; forwarded to the base class.
        """
        required = {"video_embeds", "video_image_sizes"}
        super().__init__(
            data,
            modality="video",
            required_fields=required,
            fields_factory=fields_factory,
        )

    def get_frame_size(self, index: int) -> ImageSize:
        """Return the frame size stored for item ``index``.

        Element 0 of the item's ``video_image_sizes`` entry is used as the
        width and element 1 as the height.
        """
        # NOTE(review): get_num_frames() takes len() of this same entry, which
        # suggests one row per frame, while here elements 0 and 1 are read as
        # width and height. Confirm the expected layout of video_image_sizes.
        size_entry = self.get(index)["video_image_sizes"].tolist()
        return ImageSize(width=size_entry[0], height=size_entry[1])

    def get_num_frames(self, index: int) -> int:
        """Return the length of item ``index``'s ``video_image_sizes`` entry."""
        sizes = self.get(index)["video_image_sizes"]
        return len(sizes)
class MiniCPMVMultiModalDataParser(MultiModalDataParser): class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data( def _parse_image_data(
...@@ -318,7 +304,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -318,7 +304,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(data) return MiniCPMVImageEmbeddingItems(
data,
fields_factory=_minicpmv_field_config,
)
return super()._parse_image_data(data) return super()._parse_image_data(data)
def _parse_video_data( def _parse_video_data(
...@@ -326,7 +316,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -326,7 +316,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(data) return MiniCPMVVideoEmbeddingItems(
data,
fields_factory=_minicpmv_field_config,
)
return super()._parse_video_data(data) return super()._parse_video_data(data)
...@@ -337,11 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -337,11 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config() return self.ctx.get_hf_config()
def get_hf_processor( def get_hf_processor(self, **kwargs: object):
self, hf_processor = self.ctx.get_hf_processor(**kwargs)
**kwargs: object,
):
hf_processor = self.ctx.get_hf_processor()
# NumPy arrays are considered as Iterable but not Sequence in # NumPy arrays are considered as Iterable but not Sequence in
# https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428 # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
...@@ -392,10 +383,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -392,10 +383,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_max_video_frame_tokens( return self.get_max_video_frame_tokens(
) * self.get_num_frames_with_most_features(seq_len) ) * self.get_num_frames_with_most_features(seq_len)
def get_max_audio_tokens(self) -> int:
return self.get_max_audio_tokens_per_chunk(
) * self.get_max_audio_chunks_with_most_features()
def get_slice_query_num(self) -> int: def get_slice_query_num(self) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
query_num = getattr(hf_config, "query_num", 64) query_num = getattr(hf_config, "query_num", 64)
...@@ -476,8 +463,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -476,8 +463,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return ImageSize(width=image_size, height=image_size * num_slices) return ImageSize(width=image_size, height=image_size * num_slices)
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] _I = TypeVar("_I",
): bound=MiniCPMVProcessingInfo,
default=MiniCPMVProcessingInfo)
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
...@@ -514,8 +505,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] ...@@ -514,8 +505,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
mm_data=mm_data) mm_data=mm_data)
class MiniCPMVMultiModalProcessor( class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
BaseMultiModalProcessor[MiniCPMVProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser() return MiniCPMVMultiModalDataParser()
...@@ -675,7 +665,7 @@ class MiniCPMVMultiModalProcessor( ...@@ -675,7 +665,7 @@ class MiniCPMVMultiModalProcessor(
self.info.get_video_max_slice_num() self.info.get_video_max_slice_num()
) * inputs[modality]["num_frames"][index] ) * inputs[modality]["num_frames"][index]
else: else:
raise ValueError(f"UnExpected modality: {modality}") raise ValueError(f"Unexpected modality: {modality}")
def check_mm_inputs(self, inputs: Dict[str, object], def check_mm_inputs(self, inputs: Dict[str, object],
matches: List[str]) -> None: matches: List[str]) -> None:
...@@ -700,7 +690,7 @@ class MiniCPMVMultiModalProcessor( ...@@ -700,7 +690,7 @@ class MiniCPMVMultiModalProcessor(
inputs["video"]["video_image_sizes"][index], inputs["video"]["video_image_sizes"][index],
inputs["video"]["num_frames"][index]) inputs["video"]["num_frames"][index])
else: else:
raise ValueError(f"UnExpected modality: {modality}") raise ValueError(f"Unexpected modality: {modality}")
def call_base_hf_processor( def call_base_hf_processor(
self, self,
...@@ -742,6 +732,14 @@ class MiniCPMVMultiModalProcessor( ...@@ -742,6 +732,14 @@ class MiniCPMVMultiModalProcessor(
} }
} }
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_prompt_replacements( def _get_prompt_replacements(
self, mm_items: MultiModalDataItems, self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
...@@ -770,28 +768,10 @@ class MiniCPMVMultiModalProcessor( ...@@ -770,28 +768,10 @@ class MiniCPMVMultiModalProcessor(
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0)) return _minicpmv_field_config(hf_inputs)
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"))
def apply( def apply(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment