Commit 4851c202 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1' into v0.6.1-dev

parents 9b902f9e 3fd2b0d2
...@@ -105,7 +105,7 @@ def input_processor_for_clip( ...@@ -105,7 +105,7 @@ def input_processor_for_clip(
if isinstance(image_data, Image.Image): if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config) image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
else: else:
...@@ -355,6 +355,19 @@ class CLIPVisionTransformer(nn.Module): ...@@ -355,6 +355,19 @@ class CLIPVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override) num_hidden_layers_override=num_hidden_layers_override)
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
def forward( def forward(
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
...@@ -364,7 +377,10 @@ class CLIPVisionTransformer(nn.Module): ...@@ -364,7 +377,10 @@ class CLIPVisionTransformer(nn.Module):
hidden_states = self.pre_layrnorm(hidden_states) hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states) hidden_states = self.encoder(inputs_embeds=hidden_states)
return hidden_states if self.post_layernorm is None:
return hidden_states
return self.post_layernorm(hidden_states)
class CLIPVisionModel(nn.Module): class CLIPVisionModel(nn.Module):
...@@ -386,9 +402,12 @@ class CLIPVisionModel(nn.Module): ...@@ -386,9 +402,12 @@ class CLIPVisionModel(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override) num_hidden_layers_override=num_hidden_layers_override)
def forward(self, pixel_values: Optional[torch.Tensor] = None): @property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
return self.vision_model(pixel_values=pixel_values) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)
@property @property
def device(self): def device(self):
...@@ -408,8 +427,10 @@ class CLIPVisionModel(nn.Module): ...@@ -408,8 +427,10 @@ class CLIPVisionModel(nn.Module):
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel # post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name: if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
continue continue
# omit layers when num_hidden_layers_override is set # omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name: if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3]) layer_idx = int(name.split(".")[3])
......
...@@ -47,6 +47,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -47,6 +47,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
@torch.compile @torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon): def layer_norm_func(hidden_states, weight, variance_epsilon):
...@@ -292,8 +294,7 @@ class CohereModel(nn.Module): ...@@ -292,8 +294,7 @@ class CohereModel(nn.Module):
return hidden_states return hidden_states
class CohereForCausalLM(nn.Module): class CohereForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -8,7 +8,7 @@ from transformers import PretrainedConfig ...@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
...@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class InternLM2MLP(nn.Module): class InternLM2MLP(nn.Module):
...@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module): ...@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module): ...@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
InternLMDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: InternLMDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids) return self.tok_embeddings(input_ids)
...@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module): ...@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None, intermediate_tensors: IntermediateTensors = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.tok_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.tok_embeddings(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.layers)): residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output.weight = self.model.tok_embeddings.weight self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
intermediate_tensors: IntermediateTensors, intermediate_tensors: IntermediateTensors,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import itertools import itertools
import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -16,6 +17,7 @@ from transformers import PretrainedConfig ...@@ -16,6 +17,7 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -26,6 +28,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -26,6 +28,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
...@@ -95,8 +98,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, ...@@ -95,8 +98,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
max_num: int, max_num: int, image_size: int,
image_size: int) -> Tuple[int, int, int]: use_thumbnail: bool) -> Tuple[int, int, int]:
aspect_ratio = orig_width / orig_height aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio # calculate the existing image aspect ratio
...@@ -114,17 +117,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, ...@@ -114,17 +117,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
target_width = image_size * target_aspect_ratio[0] target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1] target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks > 1
if use_thumbnail and blocks > 1:
blocks += 1
return blocks, target_width, target_height return blocks, target_width, target_height
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int, image_size: int,
use_thumbnail: int) -> List[Image.Image]: use_thumbnail: bool) -> List[Image.Image]:
orig_width, orig_height = image.size orig_width, orig_height = image.size
# calculate the number of blocks without thumbnail
blocks, target_width, target_height = calculate_num_blocks( blocks, target_width, target_height = calculate_num_blocks(
orig_width, orig_height, min_num, max_num, image_size) orig_width,
orig_height,
min_num,
max_num,
image_size,
use_thumbnail=False)
# resize the image # resize the image
resized_img = image.resize((target_width, target_height)) resized_img = image.resize((target_width, target_height))
processed_images = [] processed_images = []
...@@ -197,19 +209,25 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -197,19 +209,25 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
downsample_ratio) downsample_ratio)
image_data = multi_modal_data["image"] image_data = multi_modal_data["image"]
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if isinstance(image_data, Image.Image): if isinstance(image_data, Image.Image):
width, height = image_data.size width, height = image_data.size
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
num_blocks, _, _ = calculate_num_blocks(width, height, min_num, num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size) max_num, image_size,
# add thumbnail image if num_blocks > 1 use_thumbnail)
if hf_config.use_thumbnail and num_blocks > 1: image_feature_size = [num_blocks * num_patches]
num_blocks += 1 elif is_list_of(image_data, Image.Image):
image_feature_size = num_blocks * num_patches image_feature_size = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size,
use_thumbnail)
image_feature_size.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
...@@ -220,8 +238,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -220,8 +238,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None: if prompt is None:
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
new_prompt = prompt.replace('<image>', image_prompt, 1) new_prompt = prompt
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt) new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt, return LLMInputs(prompt=prompt,
...@@ -245,6 +269,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): ...@@ -245,6 +269,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
use_thumbnail=use_thumbnail) use_thumbnail=use_thumbnail)
# Add an N dimension for number of images per prompt (currently 1). # Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0) data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
data = [
image_to_pixel_values(img,
image_size,
min_num,
max_num,
use_thumbnail=use_thumbnail) for img in data
]
data = torch.stack(data)
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True) trust_remote_code=True)
...@@ -341,6 +374,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -341,6 +374,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
nn.Linear(llm_hidden_size, llm_hidden_size)) nn.Linear(llm_hidden_size, llm_hidden_size))
self.img_context_token_id = None self.img_context_token_id = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
...@@ -446,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -446,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
...@@ -461,7 +496,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -461,7 +496,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -38,6 +38,8 @@ from vllm.sequence import IntermediateTensors ...@@ -38,6 +38,8 @@ from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size) _get_graph_batch_size)
from .interfaces import SupportsLoRA
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -539,7 +541,7 @@ class JambaModel(nn.Module): ...@@ -539,7 +541,7 @@ class JambaModel(nn.Module):
return hidden_states return hidden_states
class JambaForCausalLM(nn.Module, HasInnerState): class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -731,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -731,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
indices_for_current_run: List[int]): indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks # move out all of the occupied but currently not running blocks
# outside of the first n blocks # outside of the first n blocks
destination_indices = set([range(batch_size)]) destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1] max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices: for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \ if destination_index in self._get_all_occupied_indices() and \
......
...@@ -387,6 +387,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -387,6 +387,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj": ("gate_up_proj", 0), "gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}
def __init__( def __init__(
self, self,
...@@ -493,6 +512,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -493,6 +512,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if ("rotary_emb.cos_cached" in name if ("rotary_emb.cos_cached" in name
...@@ -642,3 +663,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -642,3 +663,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
else: else:
raise RuntimeError("Self attention has no KV cache scaling " raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!") "factor attribute!")
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
self, name: str,
loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
def permute(w, n_heads):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
mapping = self.mistral_mapping
modules = name.split(".")
# rotary embeds should be sliced
if "wk" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
for item in modules:
if item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight
...@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, ...@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens, dummy_seq_data_for_clip, get_max_clip_image_tokens,
...@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal ...@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
image_feature_size = get_max_llava_image_tokens(ctx) image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_max_llava_image_tokens(ctx)
elif is_list_of(image_data, Image.Image):
image_feature_size = [get_max_llava_image_tokens(ctx)
] * len(image_data)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip( return input_processor_for_clip(
...@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}") f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return LlavaImageEmbeddingInputs( return LlavaImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=image_embeds, data=flatten_bn(image_embeds, concat=True),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
...@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
for img in image_data for img in image_data
] ]
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
......
import itertools
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_multimodal_embeddings)
logger = init_logger(__name__)
# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1
class LlavaNextVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
"""
def get_llava_next_video_frame_feature_size(
hf_config: LlavaNextVideoConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride
return int((image_size / patch_size / spatial_pool_stride)**2)
def _get_max_llm_tokens(ctx: InputContext) -> int:
"""
Calculated from the maximum video frames under the context length
constraints of the language model.
"""
hf_text_config = ctx.model_config.hf_text_config
model_config = ctx.model_config
max_tokens = model_config.max_model_len
rope_scaling = model_config.rope_scaling
if rope_scaling:
rope_scaling_factor = hf_text_config.rope_scaling["factor"]
else:
rope_scaling_factor = 1
max_tokens *= rope_scaling_factor
return max_tokens
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
# Currently set to 32 frames
# TODO: max_tokens = _get_max_llm_tokens(ctx)
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos != _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
frames_per_video = _MAX_FRAMES_PER_VIDEO
# num_images = num_videos * frames_per_video
# fills the sequence with as longer video data as possible
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
pil_frame = dummy_image_for_clip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava_next_video(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
frame_feature_size = \
get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = num_frames * frame_feature_size
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
# adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):
def __init__(self, config):
super().__init__()
mode = config.spatial_pool_mode
stride = config.spatial_pool_stride
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_size = image_size // patch_size**2
if mode == "average":
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
elif mode == "max":
self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
else:
# TODO: Support Conv2d pooling layer, need to load weights
raise ValueError(
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")
def forward(self, image_features):
ori_width = int(
math.sqrt(image_features.shape[1] * self.image_size //
self.image_size))
ori_height = int(ori_width * self.image_size // self.image_size)
batch_size, _, dim = image_features.shape
image_features_spatial = image_features \
.view(batch_size, ori_height, ori_height, dim) \
.permute(0, 3, 1, 2)
image_features_spatial = self.pool(image_features_spatial)
return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str):
super().__init__()
self.linear_1 = nn.Linear(vision_hidden_size,
text_hidden_size,
bias=True)
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = nn.Linear(text_hidden_size,
text_hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaNextVideoConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.vision_resampler = LlavaNextVideoPooler(config)
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
"""
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values = kwargs.pop("pixel_values_videos", None)
if pixel_values is None:
return None
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaNextVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values,
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
image_features = self.vision_resampler(image_features)
image_features = self.multi_modal_projector(image_features)
return image_features
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["data"]
if isinstance(video_pixels, torch.Tensor):
# TODO: support multiple videos per input
b, num_videos, num_frames, c, h, w = video_pixels.shape
assert (num_videos == 1)
stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels)
return stacked_embeddings.view(b, num_frames,
*stacked_embeddings.shape[1:])
elif is_list_of(video_pixels, torch.Tensor):
frames_per_videos = [v.shape[0] for v in video_pixels]
stacked_pixels = torch.cat(video_pixels, dim=0)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels)
return torch.split(stacked_embeddings, frames_per_videos, dim=0)
else:
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for LlaVA-NeXT-Video.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
video_input = self._parse_and_validate_video_input(**kwargs)
# merge video embeddings into input embeddings
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
weights, 4)
# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
...@@ -26,11 +26,9 @@ import re ...@@ -26,11 +26,9 @@ import re
from array import array from array import array
from functools import partial from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict)
import numpy as np
import torch import torch
import torch.nn.functional as F
import torch.types import torch.types
from PIL import Image from PIL import Image
from torch import nn from torch import nn
...@@ -44,6 +42,8 @@ from vllm.logger import init_logger ...@@ -44,6 +42,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
...@@ -98,101 +98,6 @@ MiniCPMVImageInputs = MiniCPMVImagePixelInputs ...@@ -98,101 +98,6 @@ MiniCPMVImageInputs = MiniCPMVImagePixelInputs
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
# tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
return (F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size[0], tgt_size[1]),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
axis=0)
else:
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
grid: np.ndarray,
version: Tuple[int, int] = (2, 0)):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
pos: np.ndarray,
version: Tuple[int, int] = (2, 0)):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else:
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
class BaseResampler(nn.Module): class BaseResampler(nn.Module):
""" """
A 2D perceiver-resampler network with one cross attention layers by A 2D perceiver-resampler network with one cross attention layers by
...@@ -245,62 +150,6 @@ class BaseResampler(nn.Module): ...@@ -245,62 +150,6 @@ class BaseResampler(nn.Module):
return query.unsqueeze(1).repeat(1, N, 1) return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2(BaseResampler):
def __init__(
self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
) -> None:
super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
norm_layer)
self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
self.pos_embed = nn.Parameter(
torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
self.apply(self._init_weights)
def forward(
self,
x: torch.Tensor,
tgt_sizes: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
):
if self.adaptive:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes,
version=(2, 0))
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
dtype=x.dtype)
else:
pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
x, _ = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask,
)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
class Resampler2_5(BaseResampler): class Resampler2_5(BaseResampler):
def __init__( def __init__(
...@@ -782,7 +631,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -782,7 +631,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
num_heads=embed_dim // 128, num_heads=embed_dim // 128,
grid_size=int(math.sqrt(self.config.query_num)), grid_size=int(math.sqrt(self.config.query_num)),
kv_dim=vision_dim, kv_dim=vision_dim,
adaptive=True, adaptive=False,
do_post_projection=True,
) )
return resampler return resampler
......
...@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue continue
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
...@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue continue
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
......
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
...@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors ...@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_multimodal_embeddings from .utils import filter_weights, merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.model": "language_model",
}
class PaliGemmaImagePixelInputs(TypedDict): class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
...@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = GemmaModel(config.text_config, cache_config, self.language_model = GemmaForCausalLM(config.text_config,
quant_config) cache_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
...@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
vision_embeddings = vision_embeddings * (self.config.hidden_size** vision_embeddings = vision_embeddings * (self.config.hidden_size**
-0.5) -0.5)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
...@@ -262,87 +260,47 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -262,87 +260,47 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, None,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
# Copied from vllm/model_executor/models/gemma.py
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.language_model.embed_tokens, return self.language_model.compute_logits(hidden_states,
hidden_states, sampling_metadata) sampling_metadata)
return logits
# Copied from vllm/model_executor/models/gemma.py
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
# Adapted from vllm/model_executor/models/gemma.py
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ # prepare weight iterators for components
# (param_name, shard_name, shard_id) vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), # load vision tower
("qkv_proj", "v_proj", "v"), vit_weights = filter_weights(vit_weights, "vision_tower")
("gate_up_proj", "gate_proj", 0), self.vision_tower.load_weights(vit_weights)
("gate_up_proj", "up_proj", 1),
] # load mlp projector
params_dict = dict(self.named_parameters()) mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
loaded_params = set() mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in mlp_weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): param = mlp_params_dict[name]
if key_to_modify in name: weight_loader = getattr(param, "weight_loader",
name = name.replace(key_to_modify, new_key) default_weight_loader)
use_default_weight_loading = False weight_loader(param, loaded_weight)
if "vision" not in name or self.vision_tower.shard_weight:
for (param_name, shard_name, # load llm backbone
shard_id) in stacked_params_mapping: llm_weights = filter_weights(llm_weights, "language_model")
if shard_name not in name: self.language_model.load_weights(llm_weights)
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
...@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width=w, input_width=w,
input_height=h)) input_height=h))
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
......
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal
from .utils import init_vllm_registered_model
def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
max_image_size = mm_encoder.mm_config.max_image_size
image_patch_size = mm_encoder.mm_config.image_patch_size
return ((max_image_size // image_patch_size)**2)
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
# approximate image size
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)
tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
return seq_data, mm_data
def input_mapper_for_pixtral(ctx: InputContext,
data: object) -> MultiModalInputs:
"""Maps the input data to its MultiModalInputs (if any).
Args:
ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped
to pixel_values in .forward() for a visual QWenLMHeadModel model.
Returns:
MultiModalInputs containing the stacked normalized images tensor or
image embeddings.
"""
# Early exit if we have provided an image to a language only Qwen model
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
data_list = data if isinstance(data, list) else [data]
images = []
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16)
images.append(image)
return MultiModalInputs({"images": images})
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: Optional[List[torch.Tensor]],
image_id: int) -> torch.Tensor:
text_locations = input_ids != image_id
image_locations = input_ids == image_id
seq_len = input_ids.shape[0]
N_txt = text_locations.sum().item()
_, D_txt = inputs_embeds.shape
N_img, D_img = image_features.shape
assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
inputs_embeds[image_locations, :] = image_features
return inputs_embeds
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
vision_args = {
key: value
for key, value in self.config.vision_config.to_dict().items()
if key in dataclass_fields
}
self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for pixtral.
TODO
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_args.image_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
def _parse_and_validate_image_input(
self,
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None
) -> Optional[List[torch.Tensor]]:
if images is None:
return None
if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
elif isinstance(images, list):
# always take last images
images = [images[-1][i] for i in range(len(images[0]))]
return images
def _process_image_input(self,
image_input: List[torch.Tensor]) -> torch.Tensor:
return self.vision_language_adapter(self.vision_encoder(image_input))
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder")
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")
def is_vision_weights(weight: Tuple[str, torch.Tensor]):
return is_vision_encoder_weights(
weight) or is_vision_lang_adapter_weights(weight)
llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
weights, 3)
# llm
llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
self.language_model.load_weights(llm_weights)
# vision encoder
vision_encoder_weights = filter(is_vision_encoder_weights,
vision_encoder_weights)
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
for name, loaded_weight in vision_encoder_weights:
# cut 'vision_encoder.'
name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[name]
default_weight_loader(param, loaded_weight)
# adapter
vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
vision_lang_adapter_weights)
vision_lang_adpter_dict = dict(
self.vision_language_adapter.named_parameters())
for name, loaded_weight in vision_lang_adapter_weights:
# cut 'vision_language_adapter.'
name = '.'.join(name.split(".")[1:])
param = vision_lang_adpter_dict[name]
default_weight_loader(param, loaded_weight)
# Vision encoder
@dataclass
class VisionEncoderArgs:
hidden_size: int
num_channels: int
image_size: int
patch_size: int
intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
rope_theta: float # for rope-2D
image_token_id: int
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
x: torch.Tensor) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim = x.ndim
assert ndim > 1
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
freqs_cis.shape,
(x.shape[1], x.shape[-1]),
)
shape = [
d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
]
return freqs_cis.view(*shape)
def precompute_freqs_cis_2d(
dim: int,
height: int,
width: int,
theta: float,
) -> torch.Tensor:
"""
freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
to be indexed by (height, width) position tuples
"""
# (dim / 2) frequency bases
freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
h = torch.arange(height, device=freqs.device)
w = torch.arange(width, device=freqs.device)
freqs_h = torch.outer(h, freqs[::2]).float()
freqs_w = torch.outer(w, freqs[1::2]).float()
freqs_2d = torch.cat(
[
freqs_h[:, None, :].repeat(1, width, 1),
freqs_w[None, :, :].repeat(height, 1, 1),
],
dim=-1,
)
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
def apply_rotary_emb_vit(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
assert freqs_cis.dtype == torch.complex64
freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class FeedForward(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
assert args.intermediate_size is not None
self.w1 = nn.Linear(args.hidden_size,
args.intermediate_size,
bias=False)
self.w2 = nn.Linear(args.intermediate_size,
args.hidden_size,
bias=False)
self.w3 = nn.Linear(args.hidden_size,
args.intermediate_size,
bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Attention(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
assert not args.hidden_size % args.num_attention_heads
self.n_heads = args.num_attention_heads
self.head_dim = args.hidden_size // args.num_attention_heads
self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
batch, patches, _ = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(batch, patches, self.n_heads, self.head_dim)
k = k.reshape(batch, patches, self.n_heads, self.head_dim)
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
out = memory_efficient_attention(q, k, v, attn_bias=mask)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)
class TransformerBlock(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = FeedForward(args)
self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(x),
mask=mask,
freqs_cis=freqs_cis)
h = x + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
return out
class Transformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(args.num_hidden_layers):
self.layers.append(TransformerBlock(args))
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
for layer in self.layers:
x = layer(x, mask=mask, freqs_cis=freqs_cis)
return x
def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
positions = torch.cat([
torch.stack(
torch.meshgrid(
torch.arange(p.shape[-2]),
torch.arange(p.shape[-1]),
indexing="ij",
),
dim=-1,
).reshape(-1, 2) for p in patch_embeds_list
])
return positions
class VisionTransformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
self.patch_conv = nn.Conv2d(
in_channels=args.num_channels,
out_channels=args.hidden_size,
kernel_size=args.patch_size,
stride=args.patch_size,
bias=False,
)
self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
self.transformer = Transformer(args)
head_dim = self.args.hidden_size // self.args.num_attention_heads
assert head_dim % 2 == 0, "ROPE requires even head_dim"
self._freqs_cis: Optional[torch.Tensor] = None
@property
def max_patches_per_side(self) -> int:
return self.args.image_size // self.args.patch_size
@property
def device(self) -> torch.device:
return next(self.parameters()).device
@property
def dtype(self) -> torch.device:
return next(self.parameters()).dtype
@property
def freqs_cis(self) -> torch.Tensor:
if self._freqs_cis is None:
self._freqs_cis = precompute_freqs_cis_2d(
dim=self.args.hidden_size // self.args.num_attention_heads,
height=self.max_patches_per_side,
width=self.max_patches_per_side,
theta=self.args.rope_theta,
)
if self._freqs_cis.device != self.device:
self._freqs_cis = self._freqs_cis.to(device=self.device)
return self._freqs_cis
def forward(
self,
images: List[torch.Tensor],
) -> torch.Tensor:
"""
Args:
images: list of N_img images of variable sizes,
each of shape (C, H, W)
Returns:
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
]
# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
positions = position_meshgrid(patch_embeds_list).to(self.device)
freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
# pass through Transformer with a block diagonal mask delimiting images
mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
# remove batch dimension of the single sequence
return out.squeeze(0)
class VisionLanguageAdapter(nn.Module):
def __init__(self, args: VisionEncoderArgs, dim: int):
super().__init__()
assert isinstance(args, VisionEncoderArgs)
self.w_in = nn.Linear(
args.hidden_size,
dim,
bias=True,
)
self.gelu = nn.GELU()
self.w_out = nn.Linear(dim, dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))
...@@ -4,41 +4,408 @@ ...@@ -4,41 +4,408 @@
# Copyright (c) Alibaba Cloud. # Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.""" """Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import math
import re
from array import array
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, TypedDict, Union)
import numpy as np
import torch import torch
from PIL import Image
from torch import nn from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PretrainedConfig from transformers import PretrainedConfig
import os import os
import re import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import print_warning_once from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
logger = init_logger(__name__)
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
# for the time being, these tags are not considered as special at encoding
# time. This may change as VLLMs multimodal API changes in the future.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_PAD = "<imgpad>"
# Image context is fixed at 256 for all images
MAX_QWEN_IMG_TOKENS = 256
# Image normalization params
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
class QwenImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, 3, image_size, image_size)`
Note that image_size is the value in the vision config to which we resize
the image to in the normalization transform. Currently multi-image support
can only be leveraged by passing image embeddings directly.
"""
class QwenImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, 256, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone
and is stored in the visual config of the model if we have one.
"""
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
class VisualAttention(nn.Module):
"""self-attention layer class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
kdim: Optional[int] = None,
vdim: Optional[int] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim \
and self.vdim == embed_dim
self.num_heads = num_heads
# Per attention head and per partition values.
assert embed_dim % num_heads == 0
self.hidden_size_per_attention_head = embed_dim // num_heads
self.num_attention_heads_per_partition = num_heads
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
assert self._qkv_same_embed_dim, \
'Visual Attention implementation only supports self-attention'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# query/key/value: [sq, b, h]
sq, b, _ = x.size()
mixed_x_layer = self.in_proj(x)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = mixed_x_layer.split(
self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(
sq, b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(
sq, b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
attention_probs = torch.baddbmm(attn_mask, q_scaled,
key_layer.transpose(-2, -1))
else:
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
value_layer = value_layer.view(
sq, b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(
b, self.num_attention_heads_per_partition, sq,
self.hidden_size_per_attention_head)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
return output
class QwenVMLP(nn.Module):
"""MLP for the visual component of the Qwen model."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config)
self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
def forward(self, x):
x, _ = self.c_fc(x)
x = self.act_fn(x)
x, _ = self.c_proj(x)
return x
class VisualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head)
self.mlp = QwenVMLP(
hidden_size=d_model,
intermediate_size=mlp_width,
quant_config=quant_config,
)
def attention(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
return self.attn(x, attn_mask=attn_mask)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
x = x + self.mlp(self.ln_2(x))
return x
class TransformerBlock(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList([
VisualAttentionBlock(width,
heads,
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def get_cast_device(self) -> torch.device:
return self.resblocks[0].mlp.c_fc.weight.device
def forward(self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
for r in self.resblocks:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
def __init__(self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
n_queries: int = 256,
output_dim: int = 512,
image_start_id: int = 151857,
quant_config: Optional[QuantizationConfig] = None,
**kwargs):
super().__init__()
image_height, image_width = self.image_size = (image_size, image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height,
image_width // patch_width)
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
# class embeddings and positional embeddings
scale = width**-0.5
self.positional_embedding = nn.Parameter(scale *
torch.randn(256, width))
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_pre = norm_layer(width)
self.transformer = TransformerBlock(width,
layers,
heads,
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config)
self.attn_pool = Resampler2(
grid_size=int(math.sqrt(n_queries)),
embed_dim=output_dim,
num_heads=output_dim // 128,
kv_dim=width,
norm_layer=norm_layer,
adaptive=False,
do_post_projection=False,
).to(
device=self.positional_embedding.device,
dtype=self.positional_embedding.dtype,
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(
(output_dim**-0.5) * torch.randn(output_dim, output_dim))
self.image_start_id = image_start_id
self.image_end_id = image_start_id + 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
x.size(1))))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x
def get_image_positions(self,
input_ids: torch.Tensor) -> Optional[torch.Tensor]:
"""Given the input IDs, extracts start/stop points corresponding to
images.
args:
Returns:
Optional torch tensor corresponding to start/stop pairs of images.
"""
if torch.any(input_ids == self.image_start_id):
bos_pos = torch.where(input_ids == self.image_start_id)
eos_pos = torch.where(input_ids == self.image_end_id)
return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
return None
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
"""MLP for the language component of the Qwen model, which contains a
MergedColumnParallelLinear merging 2 outputs via silu activation."""
def __init__( def __init__(
self, self,
...@@ -61,7 +428,7 @@ class QWenMLP(nn.Module): ...@@ -61,7 +428,7 @@ class QWenMLP(nn.Module):
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.c_proj(x) x, _ = self.c_proj(x)
...@@ -215,6 +582,9 @@ class QWenModel(nn.Module): ...@@ -215,6 +582,9 @@ class QWenModel(nn.Module):
lambda prefix: QWenBlock(config, cache_config, quant_config), lambda prefix: QWenBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.visual = VisionTransformer(**config.visual,
quant_config=quant_config) if hasattr(
config, "visual") else None
def forward( def forward(
self, self,
...@@ -223,9 +593,33 @@ class QWenModel(nn.Module): ...@@ -223,9 +593,33 @@ class QWenModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
pixel_values: Optional[QwenImageInputs],
) -> torch.Tensor: ) -> torch.Tensor:
img_pos = None
# If pixel / visual embeddings are provided, this is a visual model
if pixel_values is not None and self.visual is not None:
if pixel_values["type"] != "image_embeds":
image_embeds = self.visual(pixel_values["data"])
else:
image_embeds = pixel_values["data"]
# features should be of shape (# images, 256, hidden_dim)
img_pos = self.visual.get_image_positions(input_ids)
if isinstance(
img_pos,
np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Number of placeholders: {img_pos.shape[0]} "
f"does not match number of images {image_embeds.shape[0]}."
)
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids) hidden_states = self.wte(input_ids)
# Merge the image embeddings into the hidden states if actually have
# visual features and the corresponding image tokens
if img_pos is not None:
for idx, (img_bos, img_eos) in enumerate(img_pos):
hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -249,16 +643,241 @@ class QWenModel(nn.Module): ...@@ -249,16 +643,241 @@ class QWenModel(nn.Module):
return hidden_states return hidden_states
class QWenLMHeadModel(nn.Module): def get_image_text(image_num: int, padding: bool) -> str:
"""Retrieves a placeholder text that when tokenized, will be expanded with
image pads.
Args:
image_num: The number of the image that we want a text prompt for.
Images should be indexed starting at 1.
padding: Whether or not padding should be manually added.
Returns:
Text placeholder prompt for the image being considered.
"""
image_start = f"Picture {image_num}: {IMG_START}"
image_end = f"{IMG_END}\n"
if not padding:
return f"{image_start}{image_end}"
return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
def input_processor_for_qwen(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
"""Processes the inputs, which may or may not be multimodal.
Multimodal inputs will only be processed if the model has a "visual"
component in its model config, otherwise they'll be ignored.
Args:
ctx: Context of the loaded model.
llm_inputs: LLM inputs which may have a multi_modal_data attribute.
Returns:
If the model is language only or not multimodal inputs were provided,
returns llm_inputs unmodified. Otherwise, processes the multimodal
images / image embeddings and adds the fixed-length image placeholders.
"""
multi_modal_data = llm_inputs.get("multi_modal_data")
# Only process images if we have multimodal data and a visual config
hf_config = ctx.get_hf_config()
if (multi_modal_data is None or "image" not in multi_modal_data
or not hasattr(hf_config, "visual")):
return llm_inputs
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
image_data = multi_modal_data["image"]
if isinstance(image_data, torch.Tensor):
num_dims = len(image_data.shape)
if num_dims < 2 or num_dims > 3:
raise ValueError(
f"Expected img embeds to be have 3 dimensions, got {num_dims}")
num_images = 1 if num_dims == 2 else image_data.shape[0]
else:
# TODO - handle multiple image inputs once the API is solidified
num_images = 1
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
# Drops anything between <img>/</img> tags; encoding with the tokenizer
# will automatically add the image pads for the context.
new_prompt, num_matched_images = re.subn(
r"(Picture \d*: <img>).*?(<\/img>\n)",
r"\1\2",
prompt,
)
if num_matched_images != num_images:
logger.warning(
"Number of matched image placeholders %s doesn't match the number "
"of expected images %s; check your placeholder formatting.",
num_matched_images, num_images)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
"""Maps the input data to its MultiModalInputs (if any).
Args:
ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped
to pixel_values in .forward() for a visual QWenLMHeadModel model.
Returns:
MultiModalInputs containing the stacked normalized images tensor or
image embeddings.
"""
# Early exit if we have provided an image to a language only Qwen model
hf_config = ctx.get_hf_config()
if not hasattr(hf_config, "visual"):
logger.warning(
"Images were provided but this model has no visual config; "
"multimodal inputs will not be forwarded to the model.")
return MultiModalInputs()
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
add_special_tokens=False,
return_tensors="pt").squeeze()
image_start_id = image_pair_tok[0]
image_end_id = image_pair_tok[-1]
if (image_start_id + 1) != image_end_id:
raise ValueError(
f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
raise ValueError(
f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
f"but got {image_pair_tok - 2}")
hf_config = ctx.get_hf_config()
image_size = hf_config.visual["image_size"]
img_emb_size = hf_config.visual["output_dim"]
if isinstance(data, torch.Tensor):
# It's expected that our values have already been processed
# by the visual transformer; shape is expected to be:
# (# images, 256, hidden_size)
if len(data.shape) == 2:
# Assume only one image embed was provided; unsqueeze the extra dim
data = data.unsqueeze(0)
if len(data.shape) != 3 or data.shape[
1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
raise ValueError(
"Expected image embeds to be a tensor of shape"
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
f"received shape [{data.shape}]")
pixel_values = data
else:
transform = build_normalization_transform(image_size)
# TODO - handle multiple image inputs once the API is solidified
transformed_images = [transform(data)]
pixel_values = torch.stack(transformed_images, dim=0)
return MultiModalInputs({"pixel_values": pixel_values})
def build_normalization_transform(image_size: int) -> transforms.Compose:
"""Builds a normalization transform which can be applied to one or
more input images from which we want to extract visual features.
Args:
image_size: size of the image to be processed for visual embeddings.
Returns:
Callable transform for normalizing and resizing one RGB image.
"""
return transforms.Compose([
transforms.Resize((image_size, image_size),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
def dummy_data_for_qwen(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]:
"""Build dummy data for warming up Qwen models; this will only contain text
matching the defaults for VLLM unless the model has a visual config.
Args:
ctx: Context of the loaded model.
seq_len: Number of tokens in the text sequence.
mm_counts: multimodal data counts.
Returns:
Tuple containing sequential and multimodal data.
"""
hf_config = ctx.get_hf_config()
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
if not hasattr(hf_config, "visual"):
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
mm_data = None
return seq_data, mm_data
# We have a visual component - use images to warm up
num_images = mm_counts["image"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt = ''.join(
[get_image_text(idx, False) for idx in range(1, num_images + 1)])
toks = tokenizer.encode(image_prompt, add_special_tokens=False)
# Make sure we actually get the fixed context size per tok padding
num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
raise ValueError(
f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
f" per image, but got {num_pads} pads for {num_images} image(s)"
" in total. Are you using a qwen tokenizer?")
# Ensure the number of tokens is at minimum the sequence length provided
if len(toks) < seq_len:
toks += [0] * (seq_len - len(toks))
# Build the input images; width/height doesn't actually matter here since
# the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images}
return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(nn.Module, SupportsMultiModal):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config) self.transformer = QWenModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
...@@ -278,16 +897,47 @@ class QWenLMHeadModel(nn.Module): ...@@ -278,16 +897,47 @@ class QWenLMHeadModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def _get_image_input_type(
self, self,
input_ids: torch.Tensor, pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
positions: torch.Tensor, """Determines if the provided pixel_values are normalized pixel values
kv_caches: List[torch.Tensor], or image embeddings.
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, Args:
) -> torch.Tensor: pixel_values: Optional data to processed into visual embeddings.
Returns:
None of the QwenImageInputs type used to determine whether or not
the visual transformer needs to process the pixel_values.
"""
if pixel_values is not None and self.transformer.visual is not None:
pixel_values = flatten_bn(pixel_values)
if len(pixel_values.shape) == 3 and pixel_values.shape[
1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[
2] == self.config.visual["output_dim"]:
return QwenImageEmbeddingInputs(
type="image_embeds",
data=pixel_values,
)
else:
# If we have the wrong shape, assume we still need to process
return QwenImagePixelInputs(
type="pixel_values",
data=pixel_values,
)
return None
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
pixel_values = self._get_image_input_type(pixel_values)
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
pixel_values)
return hidden_states return hidden_states
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
...@@ -349,15 +999,6 @@ class QWenLMHeadModel(nn.Module): ...@@ -349,15 +999,6 @@ class QWenLMHeadModel(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip loading visual weights to support Qwen-VL models
# in cases with text-only inputs
# TODO: add support for Qwen-VL
if (name not in params_dict
and name.startswith("transformer.visual.")):
print_warning_once(
"Only text inputs are allowed. Images won't be handled "
"until Qwen-VL models are fully supported.")
continue
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
......
...@@ -469,7 +469,8 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -469,7 +469,8 @@ class Qwen2MoeForCausalLM(nn.Module):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue continue
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
...@@ -490,6 +491,10 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -490,6 +491,10 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -500,7 +505,8 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -500,7 +505,8 @@ class Qwen2MoeForCausalLM(nn.Module):
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue continue
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from array import array
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
Union)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from transformers import Qwen2VLConfig
from transformers.image_utils import (get_image_size,
infer_channel_dimension_format,
to_numpy_array)
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize)
import vllm.envs as envs
from vllm.attention import AttentionMetadata
from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor
logger = init_logger(__name__)
# === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
# === Vision Encoder === #
class Qwen2VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int = None,
act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.fc1 = ColumnParallelLinear(in_features,
hidden_features,
quant_config=quant_config)
self.act = act_layer()
self.fc2 = RowParallelLinear(hidden_features,
in_features,
quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
x_parallel = self.act(x_parallel)
x, _ = self.fc2(x_parallel)
return x
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1),
"... d two -> ... (d two)",
two=2)
def apply_rotary_emb_torch(x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
interleaved: bool = False) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(
sin,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[
x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
return output
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: Optional[int] = None,
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size)
self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config)
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config)
# Detect attention implementation.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
if device_available:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
self._use_flash_attn = True
else:
logger.warning(
"Current Qwen2-VL implementation has a bug with "
"`vllm-flash-attn` inside vision module, so we use "
"xformers backend instead. You can run `pip install "
"flash-attn to use flash-attention backend.")
self._use_flash_attn = False
else:
self._use_flash_attn = False
else:
if selected_backend == _Backend.FLASH_ATTN:
self._use_flash_attn = True
elif selected_backend == _Backend.XFORMERS:
self._use_flash_attn = False
else:
raise RuntimeError(
f"Qwen2-VL does not support {selected_backend} backend now."
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = x.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
batch_size = q.shape[1]
q, k, v = [
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
]
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self._use_flash_attn:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
class Qwen2VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.attn = Qwen2VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config)
self.mlp = Qwen2VisionMLP(dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config)
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb)
x = x + self.mlp(self.norm2(x))
return x
class Qwen2VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_chans: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(in_chans,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
x = self.proj(x).view(L, self.embed_dim)
return x
class Qwen2VisionPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
norm_layer: Type[nn.Module] = None,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_q = norm_layer(context_dim)
self.mlp = nn.ModuleList([
ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config),
nn.GELU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
return out
class Qwen2VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (self.theta**(torch.arange(
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
/ self.dim))
seq = torch.arange(seqlen,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Qwen2VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
patch_size: int = vision_config.patch_size
temporal_patch_size: int = vision_config.temporal_patch_size
spatial_merge_size: int = vision_config.spatial_merge_size
in_chans: int = vision_config.in_chans
hidden_size: int = vision_config.hidden_size
embed_dim: int = vision_config.embed_dim
depth: int = vision_config.depth
num_heads: int = vision_config.num_heads
mlp_ratio: float = vision_config.mlp_ratio
self.spatial_merge_size = spatial_merge_size
self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
Qwen2VisionBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
) for _ in range(depth)
])
self.merger = Qwen2VisionPatchMerger(
d_model=hidden_size,
context_dim=embed_dim,
norm_layer=norm_layer,
quant_config=quant_config,
)
@property
def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
@property
def device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
# adapter
x = self.merger(x)
return x
# === Vision input helpers === #
cached_get_processor = lru_cache(get_processor)
def mm_input_mapper_for_qwen2_vl(
ctx: InputContext,
data: MultiModalData[object],
data_type_key: str,
) -> MultiModalInputs:
"""Input mapper for Qwen2-VL."""
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
images = None
videos = None
if data_type_key == "image":
images = data
else:
assert data_type_key == "video"
videos = data
try:
batch_data = image_processor \
.preprocess(images=images, videos=videos, return_tensors="pt") \
.data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
return MultiModalInputs(batch_data)
image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
data_type_key="image")
video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
data_type_key="video")
def _get_vision_info(
image_processor,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
do_resize: bool = True,
data_type_key: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if data_type_key == "image":
grid_t = mm_count
else:
assert data_type_key == "video"
grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
grid_h = resized_height // image_processor.patch_size
grid_w = resized_width // image_processor.patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = (vision_tokens // image_processor.merge_size //
image_processor.merge_size)
return resized_height, resized_width, llm_num_vision_tokens
def _get_max_image_info(
image_processor,
data_type_key: str = "image",
mm_count: int = 1,
):
return _get_vision_info(
image_processor,
height=9999999,
width=9999999,
# Limit min / max pixels.
min_pixels=max(image_processor.min_pixels, 28 * 28),
max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
data_type_key=data_type_key,
mm_count=mm_count,
)
def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
image_processor = cached_get_image_processor(ctx.model_config.model)
max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key=data_type_key,
mm_count=1)
return max_llm_image_tokens
get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="image")
get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="video")
def dummy_data_for_qwen2_vl(
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
image_processor = cached_get_image_processor(ctx.model_config.model)
num_images = mm_counts["image"]
max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key="image",
mm_count=num_images)
if seq_len - max_llm_image_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-VL cannot process {num_images} images in a prompt, "
"please increase max_model_len or reduce image limit by "
"--limit-mm-per-prompt.")
# Check video counts.
num_videos = mm_counts["video"]
max_resized_height, max_resized_width, max_llm_video_tokens = \
_get_max_image_info(image_processor, data_type_key="video",
mm_count=num_videos)
if seq_len - max_llm_video_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-VL cannot process {num_images} videos in a prompt, "
"please increase max_model_len or reduce video limit by "
"--limit-mm-per-prompt.")
hf_config = ctx.get_hf_config(Qwen2VLConfig)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.vision_start_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.image_token_id]) * max_llm_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.vision_end_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - max_llm_image_tokens - 2)
dummy_seqdata = SequenceData(token_ids)
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)
return dummy_seqdata, {
"image": dummy_image if num_images == 1 else [dummy_image] * num_images
}
def _get_llm_num_vision_tokens(
mm_inputs: list,
data_type_key: str,
image_processor,
):
"""Get number of vision tokens of multimodal inputs.
This method is derived from `transformers.models.qwen2_vl.
image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
"""
image = to_numpy_array(mm_inputs[0])
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, channel_dim=input_data_format)
_, _, llm_num_vision_tokens = _get_vision_info(
image_processor,
height=height,
width=width,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
do_resize=image_processor.do_resize,
data_type_key=data_type_key,
mm_count=len(mm_inputs),
)
return llm_num_vision_tokens
def input_processor_for_qwen2_vl(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data", None)
if multi_modal_data is None:
return llm_inputs
image_inputs = multi_modal_data.get("image", None)
video_inputs = multi_modal_data.get("video", None)
processor = cached_get_processor(ctx.model_config.model)
image_processor = processor.image_processor
hf_config = ctx.get_hf_config(Qwen2VLConfig)
# To avoid redundant processing of vision objects (resize, rescale, etc.),
# we extract code of calculating number of vision tokens from
# `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
#
# The following code is equivalent to:
# prompt = llm_inputs["prompt"]
# inputs = processor(text=[prompt],
# images=image_inputs,
# videos=video_inputs,
# padding=True,
# return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist()
prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
if prompt_token_ids is None:
prompt = llm_inputs["prompt"]
prompt_token_ids = processor.tokenizer(
prompt,
padding=True,
return_tensors=None,
)["input_ids"]
# Expand image pad tokens.
if image_inputs is not None:
image_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
]
image_inputs = make_batched_images(image_inputs)
assert len(image_indices) == len(image_inputs)
prompt_token_ids_with_image = []
for image_cnt, image in enumerate(image_inputs):
num_image_tokens = _get_llm_num_vision_tokens(
[image],
data_type_key="image",
image_processor=image_processor,
)
if image_cnt == 0:
non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
else:
non_image_tokens = prompt_token_ids[image_indices[image_cnt -
1] +
1:image_indices[image_cnt]]
prompt_token_ids_with_image.extend(non_image_tokens)
prompt_token_ids_with_image.extend(
hf_config.image_token_id for _ in range(num_image_tokens))
prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_image
# Expand video pad tokens.
if video_inputs is not None:
video_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.video_token_id
]
video_inputs = make_batched_videos(video_inputs)
assert len(video_indices) == len(video_inputs)
prompt_token_ids_with_video = []
for video_cnt, video in enumerate(video_inputs):
num_video_tokens = _get_llm_num_vision_tokens(
video,
data_type_key="video",
image_processor=image_processor,
)
if video_cnt == 0:
non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
else:
non_video_tokens = prompt_token_ids[video_indices[video_cnt -
1] +
1:video_indices[video_cnt]]
prompt_token_ids_with_video.extend(non_video_tokens)
prompt_token_ids_with_video.extend(
hf_config.video_token_id for _ in range(num_video_tokens))
prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_video
return LLMInputs(
prompt_token_ids=prompt_token_ids,
prompt=llm_inputs["prompt"],
multi_modal_data=multi_modal_data,
)
@MULTIMODAL_REGISTRY.register_image_input_mapper(
image_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_input_mapper("video",
video_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: Qwen2VLConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching"
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen2-VL vision encoder does not support any
# quantization method now.
quant_config=None,
)
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim}")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None:
return None
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return Qwen2VLImageInputs(pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None:
return None
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return Qwen2VLVideoInputs(
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
def _process_video_input(self,
video_input: Qwen2VLVideoInputs) -> torch.Tensor:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos,
grid_thw=video_input["video_grid_thw"])
return video_embeds
def _merge_multimodal_embeddings(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: torch.Tensor,
placeholder_token_id: int,
) -> torch.Tensor:
mask = (input_ids == placeholder_token_id)
inputs_embeds[mask, :] = multimodal_embeddings
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for Qwen2-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if getattr(self.config, "rope_scaling", {}).get("type",
None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.model.embed_tokens(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
input_ids = None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name and "qkv.weight" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size,
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and "qkv.bias" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
try:
param = params_dict[name]
except KeyError:
print(params_dict.keys())
raise
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -110,7 +110,7 @@ def input_processor_for_siglip( ...@@ -110,7 +110,7 @@ def input_processor_for_siglip(
if isinstance(image_data, Image.Image): if isinstance(image_data, Image.Image):
image_feature_size = get_siglip_image_feature_size(hf_config) image_feature_size = get_siglip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
else: else:
...@@ -443,27 +443,26 @@ class SiglipVisionTransformer(nn.Module): ...@@ -443,27 +443,26 @@ class SiglipVisionTransformer(nn.Module):
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
if (num_hidden_layers_override is None
or num_hidden_layers_override == config.num_hidden_layers):
self.need_post_layernorm = True
elif num_hidden_layers_override > config.num_hidden_layers:
raise ValueError(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers")
else:
self.need_post_layernorm = False
self.embeddings = SiglipVisionEmbeddings(config) self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
) )
if self.need_post_layernorm:
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(embed_dim, self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
else: else:
self.post_layernorm = nn.Identity() # post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
self.use_head = (True if not hasattr(config, "vision_use_head") else self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head) config.vision_use_head)
if self.use_head: if self.use_head:
...@@ -482,6 +481,9 @@ class SiglipVisionTransformer(nn.Module): ...@@ -482,6 +481,9 @@ class SiglipVisionTransformer(nn.Module):
encoder_outputs = self.encoder(inputs_embeds=hidden_states) encoder_outputs = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return encoder_outputs
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
# TODO: add this back when pooled_output is used in inference # TODO: add this back when pooled_output is used in inference
# if self.use_head: # if self.use_head:
...@@ -512,8 +514,8 @@ class SiglipVisionModel(nn.Module): ...@@ -512,8 +514,8 @@ class SiglipVisionModel(nn.Module):
) )
@property @property
def need_post_layernorm(self): def _require_post_layernorm(self) -> bool:
return self.vision_model.need_post_layernorm return self.vision_model.post_layernorm is not None
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding return self.vision_model.embeddings.patch_embedding
...@@ -529,13 +531,19 @@ class SiglipVisionModel(nn.Module): ...@@ -529,13 +531,19 @@ class SiglipVisionModel(nn.Module):
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel # post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name if ("vision_model.post_layernorm" in name
and not self.need_post_layernorm): and not self._require_post_layernorm):
continue continue
# omit layers when num_hidden_layers_override is set # omit layers when num_hidden_layers_override is set
...@@ -544,7 +552,16 @@ class SiglipVisionModel(nn.Module): ...@@ -544,7 +552,16 @@ class SiglipVisionModel(nn.Module):
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue
param = params_dict[name] for (param_name, weight_name, shard_id) in stacked_params_mapping:
weight_loader = getattr(param, "weight_loader", if weight_name not in name:
default_weight_loader) continue
weight_loader(param, loaded_weight)
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: ...@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
if name.startswith(missing_layer_name): if name.startswith(missing_layer_name):
return True return True
return False return False
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
def make_empty_intermediate_tensors(
batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
key: torch.zeros((batch_size, hidden_size),
dtype=dtype,
device=device)
for key in keys
})
return make_empty_intermediate_tensors
...@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if len(inputs_list) == 0: if len(inputs_list) == 0:
return {} return {}
keys = inputs_list[0].keys()
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list: for inputs in inputs_list:
if inputs.keys() != keys: # For models that supports multiple modalities (e.g. Qwen2-VL),
msg = f"Inputs do not share the same keys ({keys})" # different modalities will return different data keys,
raise ValueError(msg) # so batch() should skip the same key check.
for k, v in inputs.items(): for k, v in inputs.items():
item_lists[k].append(v) item_lists[k].append(v)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment