Commit 0a89b8a3 authored by zhuwenwen
Browse files

support qwen2_5-vl

parent 47bd229c
...@@ -5,7 +5,7 @@ requests >= 2.26.0 ...@@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm tqdm
blake3 blake3
py-cpuinfo py-cpuinfo
transformers >= 4.48.2 # Required for Bamba model and Transformers backend. transformers >= 4.49.0 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3. tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer. protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
......
...@@ -33,18 +33,18 @@ import torch.nn as nn ...@@ -33,18 +33,18 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature from transformers import BatchFeature
from transformers.models.qwen2_5_vl import (Qwen2_5_VLImageProcessor, from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
Qwen2_5_VLProcessor)
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size() self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide( self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads) projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide( self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size) num_heads, self.tp_size)
self.qkv = ColumnParallelLinear(input_size=embed_dim, self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size, output_size=3 * projection_size,
...@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
f"Qwen2.5-VL does not support {self.attn_backend} backend now." f"Qwen2.5-VL does not support {self.attn_backend} backend now."
) )
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Split a fused QKV projection into this rank's q, k and v heads.

    Args:
        qkv: Fused projection output laid out as
            ``[s, b, 3 * head * head_dim]``.

    Returns:
        A ``(q, k, v)`` tuple, each viewed as ``[s, b, head, head_dim]``
        with ``head`` being the number of heads held by this partition.
    """
    # Sequence and batch sizes are taken from the input before any
    # gathering; only the last dimension changes below.
    seq_len, batch, _ = qkv.shape
    sharded = self.tp_size > 1

    # With tensor parallelism, rebuild the full projection first so that
    # it can be cut into q / k / v as three contiguous thirds.
    # NOTE(review): presumably needed because the fused weight is stored
    # as [q | k | v] over all heads rather than per rank -- confirm.
    if sharded:
        qkv = tensor_model_parallel_all_gather(qkv)

    # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
    q, k, v = qkv.chunk(3, dim=2)

    if sharded:

        def local_shard(full: torch.Tensor) -> torch.Tensor:
            # Keep only the slice of the last dim owned by this rank.
            pieces = dist_utils.split_tensor_along_last_dim(
                full, num_partitions=self.tp_size)
            return pieces[self.tp_rank]

        q, k, v = (local_shard(t) for t in (q, k, v))

    # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
    target_shape = (seq_len, batch, self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head)
    q, k, v = (t.view(*target_shape) for t in (q, k, v))
    return q, k, v
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
...@@ -240,22 +264,20 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -240,22 +264,20 @@ class Qwen2_5_VisionAttention(nn.Module):
# [s, b, c] --> [s, b, head * 3 * head_dim] # [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x) x, _ = self.qkv(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
new_x_shape = x.size()[:-1] + ( q, k, v = self.split_qkv(x)
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
batch_size = q.shape[1] batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v)) for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) q = apply_rotary_pos_emb_vision(q,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
k = apply_rotary_pos_emb_vision(k,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( # from vllm_flash_attn.flash_attn_interface import (
...@@ -279,20 +301,23 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -279,20 +301,23 @@ class Qwen2_5_VisionAttention(nn.Module):
"(b s) ... -> b s ...", "(b s) ... -> b s ...",
b=batch_size) b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1) # Execute attention entry by entry for speed & less VRAM.
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) outputs = []
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], start_idx = cu_seqlens[i - 1]
cu_seqlens[i - 1]:cu_seqlens[i]] = True end_idx = cu_seqlens[i]
output = F.scaled_dot_product_attention(q, q_i = q[:, start_idx:end_idx]
k, k_i = k[:, start_idx:end_idx]
v, v_i = v[:, start_idx:end_idx]
attention_mask, q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
dropout_p=0.0) for x in [q_i, k_i, v_i])
context_layer = rearrange(output, "b h s d -> b s h d ") output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -310,25 +335,6 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -310,25 +335,6 @@ class Qwen2_5_VisionAttention(nn.Module):
return output return output
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The statistics are computed in float32 and the normalised activations
    are cast back to the input dtype before the learned scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added to the mean square before the reciprocal square root.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` along its last dimension."""
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = upcast * inv_rms
        # Cast back first, then scale.
        return self.weight * normalised.to(orig_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2_5_VisionBlock(nn.Module): class Qwen2_5_VisionBlock(nn.Module):
def __init__( def __init__(
...@@ -499,8 +505,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -499,8 +505,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
) )
# NOTE: We use torch native RMSNorm here for precision purposes. norm_layer = partial(RMSNorm, eps=norm_eps)
norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
...@@ -665,24 +670,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -665,24 +670,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
if name.endswith("qkv.weight"):
visual_num_heads = self.num_heads
visual_embed_dim = self.hidden_size
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size,
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif name.endswith("qkv.bias"):
visual_num_heads = self.num_heads
visual_embed_dim = self.hidden_size
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
...@@ -701,39 +688,20 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -701,39 +688,20 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
*, *,
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
fps: Optional[float] = 2.0, size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, List[float]]] = None,
**kwargs: object,
) -> Qwen2_5_VLProcessor: ) -> Qwen2_5_VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2_5_VLProcessor) if fps is not None:
image_processor = hf_processor.image_processor # type: ignore kwargs["fps"] = fps
assert isinstance(image_processor, Qwen2_5_VLImageProcessor)
return self.ctx.get_hf_processor(
if min_pixels: Qwen2_5_VLProcessor,
image_processor.min_pixels = min_pixels image_processor=self.get_image_processor(min_pixels=min_pixels,
if max_pixels: max_pixels=max_pixels,
image_processor.max_pixels = max_pixels size=size),
if max_pixels or min_pixels: **kwargs,
image_processor.size = {
"min_pixels": image_processor.min_pixels,
"max_pixels": image_processor.max_pixels,
}
return hf_processor
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
fps: Optional[float] = 2.0,
) -> Qwen2_5_VLImageProcessor:
hf_processor = self.get_hf_processor(
min_pixels=min_pixels,
max_pixels=max_pixels,
fps=fps,
) )
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2_5_VLImageProcessor)
return image_processor
class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
...@@ -760,19 +728,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -760,19 +728,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"q_proj", "q_proj",
"k_proj", "k_proj",
"v_proj", "v_proj",
] ],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
} }
# LoRA specific attributes
# LoRA specific attributes, TODO: double check
supported_lora_modules = [ supported_lora_modules = [
# language model
"qkv_proj", "qkv_proj",
"o_proj", "o_proj",
"gate_up_proj", "gate_up_proj",
"down_proj", "down_proj", # Same name with vision encoder
"gate_proj"
"up_proj",
# vision tower # vision tower
"qkv", "qkv",
"gate_proj",
"up_proj",
"attn.proj", # Distinguish patch_embed.proj "attn.proj", # Distinguish patch_embed.proj
"fc1", "fc1",
"fc2", "fc2",
...@@ -780,6 +752,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -780,6 +752,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"mlp.0", "mlp.0",
"mlp.2" "mlp.2"
] ]
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
...@@ -1130,4 +1103,4 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1130,4 +1103,4 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return MultiModelKeys.from_string_field( return MultiModelKeys.from_string_field(
language_model="language_model", language_model="language_model",
connector="visual.", connector="visual.",
tower_model="visual.merger.") tower_model="visual.merger.")
\ No newline at end of file
...@@ -58,14 +58,17 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -58,14 +58,17 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
VideoItem) VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems, from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
MultiModalDataItems, MultiModalDataParser) ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
cached_image_processor_from_config)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, from .utils import (AutoWeightsLoader, WeightsMapper,
...@@ -231,11 +234,15 @@ def apply_rotary_emb_torch(x: torch.Tensor, ...@@ -231,11 +234,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def apply_rotary_pos_emb_vision(t: torch.Tensor, def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor: freqs: torch.Tensor,
use_flash_attn=False) -> torch.Tensor:
t_ = t.float() t_ = t.float()
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) apply_rotary_emb = apply_rotary_emb_torch
if use_flash_attn:
from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output return output
...@@ -341,20 +348,23 @@ class Qwen2VisionAttention(nn.Module): ...@@ -341,20 +348,23 @@ class Qwen2VisionAttention(nn.Module):
"(b s) ... -> b s ...", "(b s) ... -> b s ...",
b=batch_size) b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1) # Execute attention entry by entry for speed & less VRAM.
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) outputs = []
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], start_idx = cu_seqlens[i - 1]
cu_seqlens[i - 1]:cu_seqlens[i]] = True end_idx = cu_seqlens[i]
output = F.scaled_dot_product_attention(q, q_i = q[:, start_idx:end_idx]
k, k_i = k[:, start_idx:end_idx]
v, v_i = v[:, start_idx:end_idx]
attention_mask, q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
dropout_p=0.0) for x in [q_i, k_i, v_i])
context_layer = rearrange(output, "b h s d -> b s h d ") output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -710,49 +720,25 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -710,49 +720,25 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params return loaded_params
class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
dict[str, torch.Tensor]]): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data, modality) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist() return dict(
self._slices = [ pixel_values=MultiModalFieldConfig.flat_from_sizes(
slice(slice_idxs[i], slice_idxs[i + 1]) "image", image_grid_sizes),
for i in range(len(grid_thw)) image_embeds=MultiModalFieldConfig.flat_from_sizes(
] "image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
def get_count(self) -> int: pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
return len(self.data[f"{self.modality}_grid_thw"]) "video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
def get(self, index: int) -> dict[str, torch.Tensor]: "video", video_grid_sizes),
out = {} video_grid_thw=MultiModalFieldConfig.batched("video"),
for k, v in self.data.items(): )
if v != f"{self.modality}_grid_thw":
v = v[self._slices[index]]
out[k] = v
return out
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "video")
class Qwen2VLMultiModalDataParser(MultiModalDataParser): class Qwen2VLMultiModalDataParser(MultiModalDataParser):
...@@ -762,7 +748,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): ...@@ -762,7 +748,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return Qwen2VLEmbeddingItems(data, modality="image") return DictEmbeddingItems(
data,
modality="image",
required_fields={"image_embeds", "image_grid_thw"},
fields_factory=_qwen2vl_field_config,
)
return super()._parse_image_data(data) return super()._parse_image_data(data)
...@@ -771,7 +762,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): ...@@ -771,7 +762,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return Qwen2VLEmbeddingItems(data, modality="video") return DictEmbeddingItems(
data,
modality="video",
required_fields={"video_embeds", "video_grid_thw"},
fields_factory=_qwen2vl_field_config,
)
return super()._parse_video_data(data) return super()._parse_video_data(data)
...@@ -786,34 +782,64 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -786,34 +782,64 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
*, *,
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLProcessor: ) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) return self.ctx.get_hf_processor(
image_processor = hf_processor.image_processor # type: ignore Qwen2VLProcessor,
assert isinstance(image_processor, Qwen2VLImageProcessor) image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
if min_pixels: size=size),
image_processor.min_pixels = min_pixels **kwargs,
if max_pixels: )
image_processor.max_pixels = max_pixels
if max_pixels or min_pixels: def _get_image_processor_kwargs(
image_processor.size = { self,
"min_pixels": image_processor.min_pixels, *,
"max_pixels": image_processor.max_pixels, min_pixels: Optional[int] = None,
} max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
return hf_processor **kwargs: object,
):
if self.ctx.model_config.mm_processor_kwargs:
kwargs.update(self.ctx.model_config.mm_processor_kwargs)
if min_pixels is not None:
kwargs["min_pixels"] = min_pixels
if size is None:
size = {"shortest_edge": min_pixels}
else:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
kwargs["max_pixels"] = max_pixels
if size is None:
size = {"longest_edge": max_pixels}
else:
size["longest_edge"] = max_pixels
if size is not None:
kwargs["size"] = size
return kwargs
def get_image_processor( def get_image_processor(
self, self,
*, *,
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
): ):
hf_processor = self.get_hf_processor(min_pixels=min_pixels, return cached_image_processor_from_config(
max_pixels=max_pixels) self.ctx.model_config,
image_processor = hf_processor.image_processor # type: ignore **self._get_image_processor_kwargs(min_pixels=min_pixels,
assert isinstance(image_processor, Qwen2VLImageProcessor) max_pixels=max_pixels,
return image_processor size=size,
**kwargs),
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
...@@ -860,7 +886,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -860,7 +886,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
preprocessed_size = ImageSize(width=image_width, preprocessed_size = ImageSize(width=image_width,
height=image_height) height=image_height)
grid_t = max(num_frames // temporal_patch_size, 1) # NOTE: Frames are padded to be divisible by `temporal_patch_size`
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
padded_num_frames = num_frames + num_frames % temporal_patch_size
grid_t = max(padded_num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size grid_w = preprocessed_size.width // patch_size
...@@ -945,14 +975,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -945,14 +975,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
num_frames = min(max(max_total_frames // max(max_videos, 1), 1), max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO) _MAX_FRAMES_PER_VIDEO)
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412 return max(max_frames_per_video, 1)
if num_frames > 1 and num_frames % 2 == 1:
num_frames += 1
return num_frames
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
...@@ -1010,6 +1036,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -1010,6 +1036,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser() return Qwen2VLMultiModalDataParser()
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Invoke the HF processor on the prompt and multimodal data.

    Args:
        prompt: Prompt text, passed to the processor as ``text``.
        mm_data: Multimodal inputs, merged into the processor input dict
            alongside ``text``.
        mm_kwargs: Keyword arguments used both to obtain the HF processor
            and to derive the image-processor keyword arguments.

    Returns:
        Whatever ``ctx.call_hf_processor`` returns for these inputs.
    """
    call_processor = self.info.ctx.call_hf_processor
    hf_processor = self.info.get_hf_processor(**mm_kwargs)
    processor_inputs = dict(text=prompt, **mm_data)
    # The same mm_kwargs are also run through the image-processor kwargs
    # helper and forwarded as the call-time keyword arguments.
    call_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
    return call_processor(hf_processor, processor_inputs, call_kwargs)
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -1022,8 +1060,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -1022,8 +1060,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
placeholder = { placeholder = {
"image": vocab[hf_processor.image_token], "image": vocab[hf_processor.image_token],
"video": vocab[hf_processor.video_token], "video": vocab[hf_processor.video_token],
...@@ -1052,24 +1088,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -1052,24 +1088,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) return _qwen2vl_field_config(hf_inputs)
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
...@@ -1449,4 +1468,4 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1449,4 +1468,4 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return MultiModelKeys.from_string_field( return MultiModelKeys.from_string_field(
language_model="language_model", language_model="language_model",
connector="visual.", connector="visual.",
tower_model="visual.merger.") tower_model="visual.merger.")
\ No newline at end of file
...@@ -353,17 +353,17 @@ class MultiModalFieldConfig: ...@@ -353,17 +353,17 @@ class MultiModalFieldConfig:
Example: Example:
.. code-block:: .. code-block::
Input: Input:
Data: [[AAAA] Data: [[AAAA]
[BBBB] [BBBB]
[CCCC]] [CCCC]]
Output: Output:
Element 1: [AAAA] Element 1: [AAAA]
Element 2: [BBBB] Element 2: [BBBB]
Element 3: [CCCC] Element 3: [CCCC]
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalBatchedField(), field=MultiModalBatchedField(),
...@@ -384,18 +384,18 @@ class MultiModalFieldConfig: ...@@ -384,18 +384,18 @@ class MultiModalFieldConfig:
Example: Example:
.. code-block:: .. code-block::
Given: Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)] slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
Input: Input:
Data: [AAABBBBCC] Data: [AAABBBBCC]
Output: Output:
Element 1: [AAA] Element 1: [AAA]
Element 2: [BBBB] Element 2: [BBBB]
Element 3: [CC] Element 3: [CC]
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalFlatField(slices=slices), field=MultiModalFlatField(slices=slices),
...@@ -416,18 +416,18 @@ class MultiModalFieldConfig: ...@@ -416,18 +416,18 @@ class MultiModalFieldConfig:
Example: Example:
.. code-block:: .. code-block::
Given: Given:
size_per_item: [3, 4, 2] size_per_item: [3, 4, 2]
Input: Input:
Data: [AAABBBBCC] Data: [AAABBBBCC]
Output: Output:
Element 1: [AAA] Element 1: [AAA]
Element 2: [BBBB] Element 2: [BBBB]
Element 3: [CC] Element 3: [CC]
See also: See also:
:func:`MultiModalFieldConfig.flat` :func:`MultiModalFieldConfig.flat`
...@@ -456,19 +456,19 @@ class MultiModalFieldConfig: ...@@ -456,19 +456,19 @@ class MultiModalFieldConfig:
Example: Example:
.. code-block:: .. code-block::
Given: Given:
batch_size: 4 batch_size: 4
Input: Input:
Data: [XYZ] Data: [XYZ]
Output: Output:
Element 1: [XYZ] Element 1: [XYZ]
Element 2: [XYZ] Element 2: [XYZ]
Element 3: [XYZ] Element 3: [XYZ]
Element 4: [XYZ] Element 4: [XYZ]
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size), field=MultiModalSharedField(batch_size),
...@@ -738,4 +738,20 @@ class MultiModalInputs(TypedDict): ...@@ -738,4 +738,20 @@ class MultiModalInputs(TypedDict):
""" """
For each modality, information about the placeholder tokens in For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`. :code:`prompt_token_ids`.
""" """
\ No newline at end of file
class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
    ready to be passed to vLLM internals.

    Extends :class:`MultiModalInputs` with the encoder-side prompt fields
    declared below. ``encoder_token_type_ids`` is declared ``NotRequired``,
    so it may be omitted.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""

    encoder_token_type_ids: NotRequired[list[int]]
    """The token type IDs of the encoder prompt."""
\ No newline at end of file
...@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, ...@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
import numpy as np import numpy as np
import torch import torch
from PIL.Image import Image from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import TypeAlias, TypeGuard, assert_never from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .audio import resample_audio from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict, VideoItem) ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
_T = TypeVar("_T") _T = TypeVar("_T")
_I = TypeVar("_I") _I = TypeVar("_I")
...@@ -111,6 +113,64 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], ...@@ -111,6 +113,64 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
return len(self.get(item_idx)) return len(self.get(item_idx))
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                                           Mapping[str, torch.Tensor]]):
    """
    Base class for data items supplied as a mapping of field name to tensor.

    Usually, the mapping keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """
        Validate ``data`` against ``required_fields`` and index it per item.

        Args:
            data: Mapping from field name to tensor.
            modality: Name of the modality these items belong to.
            required_fields: Field names that must be present both in
                ``data`` and in the config returned by ``fields_factory``.
            fields_factory: Callable that builds the per-field config
                from ``data``.

        Raises:
            ValueError: If ``data`` or the generated field config lacks any
                of the required fields.
        """
        super().__init__(data, modality)

        # Every required field must be present in the raw data ...
        if required_fields - data.keys():
            found_keys = set(data.keys())
            raise ValueError(
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {found_keys}")

        # ... and must also be described by the generated field config.
        fields_config = fields_factory(data)
        if required_fields - fields_config.keys():
            # NOTE: the local must stay named `fields`; `{fields=}` embeds
            # the variable name in the error message.
            fields = set(fields_config.keys())
            raise ValueError(
                f"{required_fields=} should be a subset of {fields=}")

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Wrap the data the same way HF processor outputs are wrapped so
        # that counting and per-item lookup can be delegated to
        # MultiModalKwargs.
        self._kwargs = MultiModalKwargs.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        """Return the number of items stored for this modality."""
        return self._kwargs.get_item_count(self.modality)

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        """Return the ``index``-th item as a ``{field name: data}`` dict."""
        item = self._kwargs.get_item(self.modality, index)
        return {name: elem.data for name, elem in item.items()}

    def get_processor_data(self) -> Mapping[str, object]:
        """Return an empty mapping: no data is routed to the processor."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the stored data unchanged as passthrough data."""
        return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
def __init__(self, data: Sequence[HfAudioItem]) -> None: def __init__(self, data: Sequence[HfAudioItem]) -> None:
...@@ -365,4 +425,4 @@ class MultiModalDataParser: ...@@ -365,4 +425,4 @@ class MultiModalDataParser:
mm_items[k] = subparsers[k](v) mm_items[k] = subparsers[k](v)
return mm_items return mm_items
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import lru_cache from functools import lru_cache
from typing import Any, cast from typing import TYPE_CHECKING, Any, Union, cast
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar
if TYPE_CHECKING:
from vllm.config import ModelConfig
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
class HashableDict(dict):
    """
    A ``dict`` subclass that defines ``__hash__`` so that instances can be
    passed as arguments to ``lru_cache``-wrapped functions.
    """

    def __hash__(self) -> int:  # type: ignore[override]
        # NOTE: a plain dict is unhashable; for simplicity the hash is
        # derived directly from the (key, value) pairs. Using a frozenset
        # makes it independent of insertion order; every value must itself
        # be hashable.
        entries = frozenset(self.items())
        return hash(entries)
class HashableList(list):
    """
    A ``list`` subclass that defines ``__hash__`` so that instances can be
    passed as arguments to ``lru_cache``-wrapped functions.
    """

    def __hash__(self) -> int:  # type: ignore[override]
        # Hash an immutable snapshot of the current elements; every element
        # must itself be hashable.
        snapshot = tuple(self)
        return hash(snapshot)
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
base_kwargs = model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
if isinstance(value, list):
merged_kwargs[key] = HashableList(value)
return merged_kwargs
def get_processor( def get_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
trust_remote_code: bool = False, trust_remote_code: bool = False,
processor_cls: type[ProcessorMixin] = ProcessorMixin, processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
**kwargs: Any, **kwargs: Any,
): ) -> _P:
"""Load a processor for the given model name via HuggingFace.""" """Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level # don't put this import at the top level
# it will call torch.cuda.device_count() # it will call torch.cuda.device_count()
from transformers import AutoProcessor from transformers import AutoProcessor
processor_factory = (AutoProcessor processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
if processor_cls == ProcessorMixin else processor_cls) isinstance(processor_cls, tuple) else processor_cls)
try: try:
processor = processor_factory.from_pretrained( processor = processor_factory.from_pretrained(
...@@ -43,12 +87,30 @@ def get_processor( ...@@ -43,12 +87,30 @@ def get_processor(
else: else:
raise e raise e
return cast(ProcessorMixin, processor) if not isinstance(processor, processor_cls):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {processor_cls}, but "
f"found type: {type(processor)}")
return processor
cached_get_processor = lru_cache(get_processor) cached_get_processor = lru_cache(get_processor)
def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """
    Fetch the HuggingFace processor for ``model_config`` via the LRU cache.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``kwargs`` are merged over the config's
    ``mm_processor_kwargs`` before being forwarded to
    ``cached_get_processor``.
    """
    model_name = model_config.model
    trust_remote_code = model_config.trust_remote_code
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)

    return cached_get_processor(
        model_name,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **processor_kwargs,
    )
def get_image_processor( def get_image_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
...@@ -85,6 +147,20 @@ def get_image_processor( ...@@ -85,6 +147,20 @@ def get_image_processor(
return cast(BaseImageProcessor, processor) return cast(BaseImageProcessor, processor)
cached_get_image_processor = lru_cache(get_image_processor)
def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """
    Fetch the image processor for ``model_config`` via the LRU cache.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``kwargs`` are merged over the config's
    ``mm_processor_kwargs`` before being forwarded to
    ``cached_get_image_processor``.
    """
    model_name = model_config.model
    trust_remote_code = model_config.trust_remote_code
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)

    return cached_get_image_processor(
        model_name,
        trust_remote_code=trust_remote_code,
        **processor_kwargs,
    )
def get_video_processor( def get_video_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
...@@ -104,3 +180,17 @@ def get_video_processor( ...@@ -104,3 +180,17 @@ def get_video_processor(
) )
return cast(BaseImageProcessor, processor.video_processor) return cast(BaseImageProcessor, processor.video_processor)
cached_get_video_processor = lru_cache(get_video_processor)
def cached_video_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """
    Fetch the video processor for ``model_config`` via the LRU cache.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``kwargs`` are merged over the config's
    ``mm_processor_kwargs`` before being forwarded to
    ``cached_get_video_processor``.
    """
    model_name = model_config.model
    trust_remote_code = model_config.trust_remote_code
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)

    return cached_get_video_processor(
        model_name,
        trust_remote_code=trust_remote_code,
        **processor_kwargs,
    )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment