Commit 4b4eeb26 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents 2216a4e5 4fdc581f
...@@ -19,7 +19,8 @@ from vllm.attention import AttentionMetadata ...@@ -19,7 +19,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs) token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.models.intern_vit import (InternVisionModel, from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
...@@ -418,11 +419,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -418,11 +419,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
image_size = config.force_image_size or config.vision_config.image_size image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size patch_size = config.vision_config.patch_size
self.patch_size = patch_size self.patch_size = patch_size
self.select_layer = config.select_layer
self.num_image_token = int( self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2)) (image_size // patch_size)**2 * (config.downsample_ratio**2))
self.downsample_ratio = config.downsample_ratio self.downsample_ratio = config.downsample_ratio
...@@ -430,7 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -430,7 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.llm_arch_name = config.text_config.architectures[0] self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
self.vision_model = self._init_vision_model(config, self.is_mono) self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix="vision_model",
)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
...@@ -441,6 +447,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -441,6 +447,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _patch_quant_config(self, config: PretrainedConfig,
                        quant_config: QuantizationConfig):
    """Exclude the vision tower from AWQ conversion when the checkpoint
    does not say so itself.

    AWQ checkpoints published by OpenGVLab omit `modules_to_not_convert`,
    so the entry is restored here by mutating ``quant_config`` in place.

    Args:
        config: Top-level model config; its ``text_config`` is inspected
            for a ``quantization_config`` attribute.
        quant_config: Active quantization config. Anything that is not an
            ``AWQConfig`` is left untouched.
    """
    if not isinstance(quant_config, AWQConfig):
        return

    llm_quant_config = getattr(config.text_config, "quantization_config",
                               None)
    has_exclusions = bool(quant_config.modules_to_not_convert)
    if has_exclusions or llm_quant_config is None:
        # Either the checkpoint already lists excluded modules, or the
        # language model carries no quantization config of its own.
        return

    # NOTE(review): assumes `modules_to_not_convert` is a (possibly empty)
    # list rather than None -- confirm against AWQConfig.
    quant_config.modules_to_not_convert.append("vision_model")
@cached_property @cached_property
def sampler(self): def sampler(self):
if hasattr(self.language_model, "sampler"): if hasattr(self.language_model, "sampler"):
...@@ -448,17 +466,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -448,17 +466,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return Sampler() return Sampler()
def _init_vision_model(self, config: PretrainedConfig, is_mono: bool): def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
is_mono: bool,
prefix: str,
):
if not is_mono: if not is_mono:
vision_feature_layer = self.select_layer vision_feature_layer = config.select_layer
if vision_feature_layer < 0: if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \ num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1 + vision_feature_layer + 1
else: else:
num_hidden_layers = vision_feature_layer + 1 num_hidden_layers = vision_feature_layer + 1
return InternVisionModel( return InternVisionModel(
config.vision_config, config.vision_config,
num_hidden_layers_override=num_hidden_layers) quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
else: else:
return InternVisionPatchModel(config.vision_config) return InternVisionPatchModel(config.vision_config)
......
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
SiglipVisionConfig) PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
...@@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaConfig): class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: int
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
):
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
...@@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig): ...@@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig):
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
) )
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
) )
elif isinstance(vision_config, PixtralVisionConfig): elif isinstance(vision_config, PixtralVisionConfig):
# TODO: allow layer override? return PixtralHFVisionModel(
return PixtralHFVisionModel(vision_config) vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
...@@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config.projector_hidden_act = "gelu" config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config) self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
......
...@@ -26,7 +26,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, ...@@ -26,7 +26,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size, dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
...@@ -259,32 +259,6 @@ def input_processor_for_llava_next(ctx: InputContext, ...@@ -259,32 +259,6 @@ def input_processor_for_llava_next(ctx: InputContext,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextConfig):
    """Build the vision encoder, truncated at the configured feature layer.

    Args:
        hf_config: Model config providing ``vision_config`` and
            ``vision_feature_layer``.

    Returns:
        A ``CLIPVisionModel`` or ``SiglipVisionModel``, chosen by the type
        of ``hf_config.vision_config``.

    Raises:
        NotImplementedError: If the vision config is of neither type.
    """
    vision_config = hf_config.vision_config

    # Only the layers up to and including the feature layer are needed.
    feature_layer = hf_config.vision_feature_layer
    depth = feature_layer + 1
    if feature_layer < 0:
        # A negative index counts back from the last encoder layer.
        depth += hf_config.vision_config.num_hidden_layers

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return tower_cls(
        vision_config,
        num_hidden_layers_override=depth,
    )
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
...@@ -303,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -303,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config) self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
......
...@@ -26,6 +26,7 @@ from vllm.utils import is_list_of ...@@ -26,6 +26,7 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
...@@ -179,32 +180,6 @@ def input_processor_for_llava_next_video(ctx: InputContext, ...@@ -179,32 +180,6 @@ def input_processor_for_llava_next_video(ctx: InputContext,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
    """Build the vision encoder, truncated at the configured feature layer.

    Args:
        hf_config: Model config providing ``vision_config`` and
            ``vision_feature_layer``.

    Returns:
        A ``CLIPVisionModel`` or ``SiglipVisionModel``, chosen by the type
        of ``hf_config.vision_config``.

    Raises:
        NotImplementedError: If the vision config is of neither type.
    """
    vision_config = hf_config.vision_config

    # Only the layers up to and including the feature layer are needed.
    feature_layer = hf_config.vision_feature_layer
    depth = feature_layer + 1
    if feature_layer < 0:
        # A negative index counts back from the last encoder layer.
        depth += hf_config.vision_config.num_hidden_layers

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return tower_cls(
        vision_config,
        num_hidden_layers_override=depth,
    )
# adopted from transformers modeling_llava_next_video.py # adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module): class LlavaNextVideoPooler(nn.Module):
...@@ -281,7 +256,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -281,7 +256,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config) self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.vision_resampler = LlavaNextVideoPooler(config) self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector( self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
......
...@@ -31,6 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, ...@@ -31,6 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size, dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size, dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
...@@ -357,32 +358,6 @@ def input_processor_for_llava_onevision(ctx: InputContext, ...@@ -357,32 +358,6 @@ def input_processor_for_llava_onevision(ctx: InputContext,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
    """Build the vision encoder, truncated at the configured feature layer.

    Args:
        hf_config: Model config providing ``vision_config`` and
            ``vision_feature_layer``.

    Returns:
        A ``CLIPVisionModel`` or ``SiglipVisionModel``, chosen by the type
        of ``hf_config.vision_config``.

    Raises:
        NotImplementedError: If the vision config is of neither type.
    """
    vision_config = hf_config.vision_config

    # Only the layers up to and including the feature layer are needed.
    feature_layer = hf_config.vision_feature_layer
    depth = feature_layer + 1
    if feature_layer < 0:
        # A negative index counts back from the last encoder layer.
        depth += hf_config.vision_config.num_hidden_layers

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return tower_cls(
        vision_config,
        num_hidden_layers_override=depth,
    )
class LlavaOnevisionMultiModalProjector(nn.Module): class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig): def __init__(self, config: LlavaOnevisionConfig):
...@@ -425,7 +400,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -425,7 +400,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config) self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
......
...@@ -395,7 +395,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -395,7 +395,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.version = get_version_by_config(self.config) self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config, cache_config, quant_config) self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module() self.vpm = self.init_vision_module(config, quant_config)
param_dtype = torch.get_default_dtype() param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype) self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
...@@ -647,7 +647,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -647,7 +647,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> nn.Module: ) -> nn.Module:
raise NotImplementedError raise NotImplementedError
def init_vision_module(self) -> nn.Module: def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
raise NotImplementedError raise NotImplementedError
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
...@@ -693,7 +697,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -693,7 +697,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config), quant_config=quant_config),
name="model") name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
# TODO :refactor this vision model # TODO :refactor this vision model
try: try:
import timm import timm
...@@ -817,8 +825,13 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -817,8 +825,13 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config), quant_config=quant_config),
name="model") name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(
model = Idefics2VisionTransformer(self.config.vision_config) self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
return model return model
...@@ -929,9 +942,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -929,9 +942,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config), quant_config=quant_config),
name="model") name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(
self,
model = Idefics2VisionTransformer(self.config.vision_config) config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
return model return model
......
...@@ -379,9 +379,13 @@ class MllamaVisionSdpaAttention(nn.Module): ...@@ -379,9 +379,13 @@ class MllamaVisionSdpaAttention(nn.Module):
class MllamaVisionEncoderLayer(nn.Module): class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, def __init__(
config: config_mllama.MllamaVisionConfig, self,
is_gated: bool = False): config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
is_gated: bool = False,
) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -390,7 +394,9 @@ class MllamaVisionEncoderLayer(nn.Module): ...@@ -390,7 +394,9 @@ class MllamaVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(config) self.self_attn = MllamaVisionSdpaAttention(config)
self.mlp = CLIPMLP(config) self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(self.hidden_size, self.input_layernorm = nn.LayerNorm(self.hidden_size,
eps=config.norm_eps) eps=config.norm_eps)
...@@ -427,16 +433,23 @@ class MllamaVisionEncoderLayer(nn.Module): ...@@ -427,16 +433,23 @@ class MllamaVisionEncoderLayer(nn.Module):
class MllamaVisionEncoder(nn.Module): class MllamaVisionEncoder(nn.Module):
def __init__(self, def __init__(
config: config_mllama.MllamaVisionConfig, self,
num_layers=32, config: config_mllama.MllamaVisionConfig,
is_gated=False, quant_config: Optional[QuantizationConfig],
output_hidden_states=None): num_layers: int = 32,
is_gated: bool = False,
output_hidden_states=None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MllamaVisionEncoderLayer(config, is_gated) MllamaVisionEncoderLayer(config,
for _ in range(num_layers) quant_config=quant_config,
is_gated=is_gated,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_layers)
]) ])
self.output_hidden_states = output_hidden_states or [] self.output_hidden_states = output_hidden_states or []
...@@ -463,8 +476,14 @@ class MllamaVisionEncoder(nn.Module): ...@@ -463,8 +476,14 @@ class MllamaVisionEncoder(nn.Module):
class MllamaVisionModel(nn.Module): class MllamaVisionModel(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig): def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.image_size = config.image_size self.image_size = config.image_size
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles self.max_num_tiles = config.max_num_tiles
...@@ -500,12 +519,19 @@ class MllamaVisionModel(nn.Module): ...@@ -500,12 +519,19 @@ class MllamaVisionModel(nn.Module):
# encoders # encoders
self.transformer = MllamaVisionEncoder( self.transformer = MllamaVisionEncoder(
config, config,
quant_config,
config.num_hidden_layers, config.num_hidden_layers,
is_gated=False, is_gated=False,
output_hidden_states=config.intermediate_layers_indices) output_hidden_states=config.intermediate_layers_indices,
self.global_transformer = MllamaVisionEncoder(config, prefix=f"{prefix}.transformer",
config.num_global_layers, )
is_gated=True) self.global_transformer = MllamaVisionEncoder(
config,
quant_config,
config.num_global_layers,
is_gated=True,
prefix=f"{prefix}.global_transformer",
)
def apply_class_embedding(self, def apply_class_embedding(self,
hidden_state: torch.Tensor) -> torch.Tensor: hidden_state: torch.Tensor) -> torch.Tensor:
...@@ -648,6 +674,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -648,6 +674,7 @@ class MllamaTextCrossAttention(nn.Module):
config: Optional[config_mllama.MllamaTextConfig] = None, config: Optional[config_mllama.MllamaTextConfig] = None,
layer_idx: Optional[int] = None, layer_idx: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -673,6 +700,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -673,6 +700,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_key_value_heads, self.num_key_value_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim, self.num_heads * self.head_dim,
...@@ -680,6 +708,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -680,6 +708,7 @@ class MllamaTextCrossAttention(nn.Module):
bias=False, bias=False,
input_is_parallel=True, input_is_parallel=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue, # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead # use huggingface's instead
...@@ -692,6 +721,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -692,6 +721,7 @@ class MllamaTextCrossAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
self.num_local_key_value_heads, self.num_local_key_value_heads,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
...@@ -765,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -765,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
kv_len = k.shape[0] kv_len = k.shape[0]
q = q.transpose(0, 1).view(self.num_local_key_value_heads, q = q.transpose(0, 1).view(self.num_local_key_value_heads,
self.num_key_value_groups, q_len, self.num_key_value_groups, q_len,
self.head_dim) self.head_dim).contiguous()
k = k.transpose(0, k = k.transpose(0,
1)[:, 1)[:,
None, :, :].expand(self.num_local_key_value_heads, None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups, self.num_key_value_groups,
kv_len, self.head_dim) kv_len,
self.head_dim).contiguous()
v = v.transpose(0, v = v.transpose(0,
1)[:, 1)[:,
None, :, :].expand(self.num_local_key_value_heads, None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups, self.num_key_value_groups,
kv_len, self.head_dim) kv_len,
self.head_dim).contiguous()
attention_mask = attention_mask.view(1, 1, q_len, kv_len) attention_mask = attention_mask.view(1, 1, q_len, kv_len)
output = F.scaled_dot_product_attention(q, output = F.scaled_dot_product_attention(q,
k, k,
...@@ -791,15 +823,21 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -791,15 +823,21 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention """Cross-attention transformer block with tanh-gated attention
and feedforward.""" and feedforward."""
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int, def __init__(
quant_config: Optional[QuantizationConfig]) \ self,
-> None: config: config_mllama.MllamaTextConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.cross_attn = MllamaTextCrossAttention( self.cross_attn = MllamaTextCrossAttention(
config=config, config=config,
layer_idx=layer_idx, layer_idx=layer_idx,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.cross_attn",
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
...@@ -811,6 +849,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -811,6 +849,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp",
) )
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -854,10 +893,15 @@ class MllamaTextModel(nn.Module): ...@@ -854,10 +893,15 @@ class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model" base_model_prefix = "model"
def __init__(self, config: config_mllama.MllamaTextConfig, def __init__(
cache_config: Optional[CacheConfig], self,
quant_config: Optional[QuantizationConfig]): config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
...@@ -869,13 +913,20 @@ class MllamaTextModel(nn.Module): ...@@ -869,13 +913,20 @@ class MllamaTextModel(nn.Module):
if layer_idx in self.cross_attention_layers: if layer_idx in self.cross_attention_layers:
layers.append( layers.append(
MllamaCrossAttentionDecoderLayer( MllamaCrossAttentionDecoderLayer(
config, layer_idx, quant_config=quant_config)) config,
layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
else: else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False # TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append( layers.append(
LlamaDecoderLayer(config, LlamaDecoderLayer(
cache_config=cache_config, config,
quant_config=quant_config)) cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
self.layers = nn.ModuleList(layers) self.layers = nn.ModuleList(layers)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -932,12 +983,19 @@ class MllamaForCausalLM(nn.Module): ...@@ -932,12 +983,19 @@ class MllamaForCausalLM(nn.Module):
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
] ]
def __init__(self, config: config_mllama.MllamaTextConfig, def __init__(
cache_config: Optional[CacheConfig], self,
quant_config: Optional[QuantizationConfig]): config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, cache_config, quant_config) self.model = MllamaTextModel(config,
cache_config,
quant_config,
prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
...@@ -994,11 +1052,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -994,11 +1052,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
config.pad_token_id if config.pad_token_id is not None else -1 config.pad_token_id if config.pad_token_id is not None else -1
self.image_size = config.vision_config.image_size self.image_size = config.vision_config.image_size
self.vision_model = MllamaVisionModel(config.vision_config) self.vision_model = MllamaVisionModel(config.vision_config,
quant_config,
prefix="vision_model")
self.language_model = MllamaForCausalLM( self.language_model = MllamaForCausalLM(
config.text_config, config.text_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix="language_model",
) )
self.multi_modal_projector = nn.Linear( self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim, config.vision_config.vision_output_dim,
......
...@@ -30,21 +30,21 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -30,21 +30,21 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import make_layers
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from .utils import get_vit_attn_backend from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend,
make_empty_intermediate_tensors_factory, make_layers)
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
...@@ -744,6 +744,10 @@ class MolmoModel(nn.Module): ...@@ -744,6 +744,10 @@ class MolmoModel(nn.Module):
assert config.layer_norm_type == "rms" assert config.layer_norm_type == "rms"
self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -925,16 +929,19 @@ def pad_images( ...@@ -925,16 +929,19 @@ def pad_images(
def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
prompt = inputs.get("prompt", None) prompt = inputs.get("prompt")
multi_modal_data = inputs.get("multi_modal_data", None) multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is not None: image = None if multi_modal_data is None else multi_modal_data.get("image")
image = multi_modal_data.get("image", None)
else:
image = None
processor = cached_get_processor(ctx.model_config.model, processor = cached_get_processor(ctx.model_config.model,
trust_remote_code=True, trust_remote_code=True,
revision=ctx.model_config.code_revision) revision=ctx.model_config.code_revision)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
# NOTE: message formatting for raw text prompt is only applied for # NOTE: message formatting for raw text prompt is only applied for
# offline inference; for online inference, the prompt is always in # offline inference; for online inference, the prompt is always in
# instruction format and tokenized. # instruction format and tokenized.
...@@ -997,9 +1004,13 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -997,9 +1004,13 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = dict(image=image_data) multi_modal_data = dict(image=image_data)
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(out["input_ids"])
return token_inputs( return token_inputs(
prompt_token_ids=out["input_ids"], prompt_token_ids=out["input_ids"],
prompt=inputs["prompt"], prompt=prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
...@@ -1008,7 +1019,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -1008,7 +1019,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal): class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -1040,6 +1051,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal): ...@@ -1040,6 +1051,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal):
or config.vocab_size) or config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, self,
**kwargs: object, **kwargs: object,
...@@ -1123,31 +1137,36 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal): ...@@ -1123,31 +1137,36 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal):
positions: torch.LongTensor, positions: torch.LongTensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None:
inputs_embeds = self.model.embed_tokens(input_ids)
if image_input is not None: image_features = self._process_image_input(image_input)
inputs_embeds = self.model.embed_tokens(input_ids)
image_features = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings( inputs_embeds = self._merge_multimodal_embeddings(
inputs_embeds, inputs_embeds,
image_features, image_features,
image_input["image_input_idx"], image_input["image_input_idx"],
image_input["seq_len"], image_input["seq_len"],
) )
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
......
...@@ -4,10 +4,13 @@ ...@@ -4,10 +4,13 @@
# Copyright (c) 2024 NVIDIA # Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from typing import Optional
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
...@@ -55,10 +58,31 @@ class NVLM_D_Model(InternVLChatModel): ...@@ -55,10 +58,31 @@ class NVLM_D_Model(InternVLChatModel):
nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False), nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
) )
def _init_vision_model(self, config: PretrainedConfig, def _init_vision_model(
num_hidden_layers: int): self,
# We added additional dummy heads to the original num of heads to make config: PretrainedConfig,
# the number of heads divisible by 8. quant_config: Optional[QuantizationConfig],
return InternVisionModel(config.vision_config, *,
num_hidden_layers_override=num_hidden_layers, is_mono: bool,
num_dummy_heads=7) prefix: str,
):
if not is_mono:
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# We added additional dummy heads to the original num of heads to
# make the number of heads divisible by 8.
return InternVisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
num_dummy_heads=7,
prefix=prefix,
)
else:
msg = "Monolith mode is not applicable to NVLM_D"
raise NotImplementedError(msg)
...@@ -142,7 +142,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -142,7 +142,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(config.vision_config) self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector( self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
......
...@@ -70,7 +70,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, ...@@ -70,7 +70,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768) projection_dim=768)
def _init_img_processor(hf_config: PretrainedConfig): def _init_img_processor(hf_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]):
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2) layer_idx = hf_config.img_processor.get('layer_idx', -2)
...@@ -82,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig): ...@@ -82,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig):
num_hidden_layers = layer_idx + 1 num_hidden_layers = layer_idx + 1
img_processor = CLIPVisionModel( img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers) clip_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
)
return img_processor return img_processor
...@@ -148,14 +152,15 @@ class Phi3ImageEmbeddingBase(nn.Module): ...@@ -148,14 +152,15 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform.""" """Phi3 Image embedding with HD transform."""
def __init__(self, config: PretrainedConfig) -> None: def __init__(self, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]) -> None:
super().__init__() super().__init__()
# n_embed or hidden_size # n_embed or hidden_size
hidden_size = config.n_embd if hasattr( hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size config, 'n_embd') else config.hidden_size
self.img_processor = _init_img_processor(config) self.img_processor = _init_img_processor(config, quant_config)
image_dim_out = config.img_processor['image_dim_out'] image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens'] self.num_img_tokens = config.img_processor['num_img_tokens']
...@@ -535,7 +540,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -535,7 +540,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) )
# TODO: Optionally initializes this for supporting input embeddings. # TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config) self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
self.language_model = LlamaForCausalLM(config, cache_config, self.language_model = LlamaForCausalLM(config, cache_config,
quant_config) quant_config)
......
...@@ -767,9 +767,17 @@ def input_processor_for_pixtral_hf( ...@@ -767,9 +767,17 @@ def input_processor_for_pixtral_hf(
class PixtralHFMLP(nn.Module): class PixtralHFMLP(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
assert config.intermediate_size is not None assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this
self.gate_proj = nn.Linear(config.hidden_size, self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=False) bias=False)
...@@ -787,8 +795,15 @@ class PixtralHFMLP(nn.Module): ...@@ -787,8 +795,15 @@ class PixtralHFMLP(nn.Module):
class PixtralHFAttention(nn.Module): class PixtralHFAttention(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.hidden_size % config.num_attention_heads assert not config.hidden_size % config.num_attention_heads
self.n_heads = config.num_attention_heads self.n_heads = config.num_attention_heads
...@@ -796,6 +811,7 @@ class PixtralHFAttention(nn.Module): ...@@ -796,6 +811,7 @@ class PixtralHFAttention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
# TODO: Use quant_config and prefix after optimizing this
self.q_proj = nn.Linear(config.hidden_size, self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False)
...@@ -840,11 +856,22 @@ class PixtralHFAttention(nn.Module): ...@@ -840,11 +856,22 @@ class PixtralHFAttention(nn.Module):
class PixtralHFTransformerBlock(nn.Module): class PixtralHFTransformerBlock(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(config) self.attention = PixtralHFAttention(config,
self.feed_forward = PixtralHFMLP(config) quant_config=quant_config,
prefix=f"{prefix}.attention")
self.feed_forward = PixtralHFMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward")
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
def forward( def forward(
...@@ -864,11 +891,27 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -864,11 +891,27 @@ class PixtralHFTransformerBlock(nn.Module):
class PixtralHFTransformer(nn.Module): class PixtralHFTransformer(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers): if num_hidden_layers_override is None:
self.layers.append(PixtralHFTransformerBlock(config)) num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
PixtralHFTransformerBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward( def forward(
self, self,
...@@ -883,7 +926,15 @@ class PixtralHFTransformer(nn.Module): ...@@ -883,7 +926,15 @@ class PixtralHFTransformer(nn.Module):
class PixtralHFVisionModel(nn.Module): class PixtralHFVisionModel(nn.Module):
def __init__(self, config: PixtralVisionConfig): def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -895,7 +946,24 @@ class PixtralHFVisionModel(nn.Module): ...@@ -895,7 +946,24 @@ class PixtralHFVisionModel(nn.Module):
bias=False, bias=False,
) )
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(config) self.transformer = PixtralHFTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)
num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers.")
if require_post_norm is True:
msg = "PixtralHFVisionModel does not have post-layernorm"
raise ValueError(msg)
self.dtype = next(self.parameters()).dtype self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding( self.patch_positional_embedding = PixtralRotaryEmbedding(
......
# coding=utf-8
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
import librosa
import numpy as np
import torch
import torch.nn as nn
from transformers import Qwen2AudioConfig, Qwen2AudioEncoder
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP
# Logger for this module, obtained from vLLM's `init_logger`.
logger = init_logger(__name__)

# Substring rewrites applied to checkpoint weight names while loading weights:
# whenever a key occurs in a weight name it is replaced by the mapped value.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}
# === Audio Inputs === #
class Qwen2AudioInputs(TypedDict):
    """Audio tensors handed to the model for one batch of audio clips."""

    # Shape: `(num_audios, num_mel_bins, 3000)`
    input_features: torch.Tensor

    # Shape: `(num_audios, 3000)`
    feature_attention_mask: torch.Tensor
# === Audio Encoder === #
class Qwen2AudioMultiModalProjector(nn.Module):
    """Biased linear layer from the audio hidden size to the text hidden size."""

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(
            in_features=audio_hidden_size,
            out_features=text_hidden_size,
            bias=True,
        )

    def forward(self, audio_features):
        """Apply the linear projection to `audio_features`."""
        return self.linear(audio_features)
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
                               mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy audio clips for `seq_len` tokens.

    This function is registered below as the dummy-data factory for the model.

    Args:
        ctx: Input context; its HF config supplies the audio token index.
        seq_len: Total length of the dummy token sequence.
        mm_counts: Per-modality item counts; only ``"audio"`` is read.

    Returns:
        A tuple of the dummy ``SequenceData`` and a dict mapping ``"audio"``
        to a list of ``(waveform, 16000)`` pairs, one per audio.

    Raises:
        RuntimeError: If fewer than two tokens would remain after reserving
            the maximum number of audio tokens.
    """
    audio_count = mm_counts["audio"]
    audio_token_budget = get_max_qwen2_audio_audio_tokens(ctx) * audio_count
    remaining_tokens = seq_len - audio_token_budget
    if remaining_tokens < 2:
        raise RuntimeError(
            f"Qwen2-Audio cannot process {audio_count} audios in a prompt, "
            "please increase max_model_len or reduce audio limit by "
            "--limit-mm-per-prompt.")

    audio_token_index = ctx.model_config.hf_config.audio_token_index

    # Audio placeholder tokens first, then token id 0 for the rest.
    dummy_seqdata = SequenceData.from_prompt_token_counts(
        (audio_token_index, audio_token_budget),
        (0, remaining_tokens),
    )

    # NOTE(review): 2 * 2 * 160 looks like samples per output token (two
    # stride-2 stages times a 160-sample hop) -- confirm.
    num_samples = audio_token_budget * 2 * 2 * 160
    dummy_audio = np.full((num_samples, ), 0.)
    return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * audio_count}
def get_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    **kwargs,
):
    """Gets a processor for the given model name via HuggingFace.

    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.

    Args:
        processor_name: Name or path passed to `AutoProcessor.from_pretrained`.
        *args: Extra positional arguments forwarded unchanged.
        trust_remote_code: Forwarded to `AutoProcessor.from_pretrained`.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The processor returned by `AutoProcessor.from_pretrained`.

    Raises:
        RuntimeError: If loading raises `ValueError` while
            `trust_remote_code` is false.
        ValueError: Re-raised as-is when `trust_remote_code` is true.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor

    try:
        return AutoProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        if trust_remote_code:
            # Remote code was already allowed, so there is no hint to add.
            raise
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        err_msg = (
            "Failed to load the processor. If the processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI.")
        raise RuntimeError(err_msg) from e


# Memoized variant of `get_processor` (default `lru_cache` settings).
cached_get_processor = lru_cache(get_processor)
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
and the output length of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
    """Return the largest number of audio tokens one audio can occupy.

    Computed from `audio_config.max_source_positions` in the HF config as
    `(max_source_positions - 2) // 2 + 1`.
    """
    audio_config = ctx.model_config.hf_config.audio_config
    return (audio_config.max_source_positions - 2) // 2 + 1
def input_processor_for_qwen2_audio(
        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    """Expand every audio placeholder token to one token per audio feature.

    Each occurrence of the audio token in ``prompt_token_ids`` is replaced by
    a run of that token whose length is the encoder output length computed
    for the corresponding audio clip.

    Args:
        ctx: Input context; supplies the model name (used to load the
            processor) and the HF config's ``audio_token_index``.
        inputs: Decoder-only inputs. ``multi_modal_data["audio"]`` is a single
            ``(waveform, sampling_rate)`` pair or a list of such pairs.

    Returns:
        ``inputs`` unchanged when there is no audio; otherwise new token
        inputs with expanded ``prompt_token_ids``. The text prompt is carried
        over when present and may be absent for token-only requests.

    Raises:
        AssertionError: If the number of audio placeholder tokens differs from
            the number of audio clips.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return inputs

    audios = multi_modal_data["audio"]
    if not isinstance(audios, list):
        audios = [audios]

    if not audios:
        return inputs

    processor = cached_get_processor(ctx.model_config.model)
    target_sr = processor.feature_extractor.sampling_rate
    resampled_audios = [
        librosa.resample(audio, orig_sr=sampling_rate, target_sr=target_sr)
        for audio, sampling_rate in audios
    ]

    # Feature length per clip, capped at 3000.
    # NOTE(review): 160 is presumably the extractor hop length -- confirm.
    audio_input_lengths = np.array([
        min(3000, waveform.shape[0] // 160 + 1)
        for waveform in resampled_audios
    ])

    # Only the final (encoder output) lengths are needed here.
    _, audio_output_lengths = _get_feat_extract_output_lengths(
        audio_input_lengths)

    audio_token_index = ctx.model_config.hf_config.audio_token_index

    input_ids = inputs['prompt_token_ids']

    audio_num = input_ids.count(audio_token_index)
    assert len(audio_input_lengths) == audio_num, \
        (f'The text input contains {audio_num} audio tokens, '
         f'but {len(audio_input_lengths)} audios provided')

    new_input_ids = []
    start = 0
    for audio_idx in range(audio_num):
        end = input_ids.index(audio_token_index, start)
        new_input_ids.extend(input_ids[start:end])  # text part

        # Explicit int() so list repetition does not rely on NumPy scalar
        # arithmetic.
        num_audio_tokens = int(audio_output_lengths[audio_idx])
        new_input_ids.extend([audio_token_index] * num_audio_tokens)
        start = end + 1
    new_input_ids.extend(input_ids[start:])

    return token_inputs(
        prompt_token_ids=new_input_ids,
        # Token-only requests carry no "prompt" key; indexing would raise
        # KeyError, so forward None instead.
        prompt=inputs.get("prompt"),
        multi_modal_data=multi_modal_data,
    )
def input_mapper_for_qwen2_audio(
    ctx: InputContext,
    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalInputs:
    """Input mapper for Qwen2-Audio.

    Resamples every audio clip to the feature extractor's sampling rate and
    runs the HuggingFace feature extractor on the batch.

    Args:
        ctx: Input context; supplies the model name used to load the
            processor.
        multi_modal_data: One audio item or a list of them. Each item is
            unpacked as ``(waveform, sampling_rate)``.

    Returns:
        ``MultiModalInputs`` built from the extractor output, with
        ``attention_mask`` renamed to ``feature_attention_mask``. Empty when
        no audio is given.

    Raises:
        RuntimeError: If the processor has no feature extractor.
        Exception: Any error raised during resampling or feature extraction
            is logged and re-raised unchanged.
    """
    if not isinstance(multi_modal_data, list):
        multi_modal_data = [multi_modal_data]

    if len(multi_modal_data) == 0:
        return MultiModalInputs()

    processor = cached_get_processor(ctx.model_config.model)
    audio_feature_extractor = processor.feature_extractor
    if audio_feature_extractor is None:
        raise RuntimeError(
            "No HuggingFace audio_feature_extractor is available "
            "to process the audio object")

    try:
        # Resample to, and declare, the extractor's own rate so the two can
        # never disagree (previously 16000 was hard-coded for the call).
        target_sr = audio_feature_extractor.sampling_rate
        resampled_audios = [
            librosa.resample(audio,
                             orig_sr=sampling_rate,
                             target_sr=target_sr)
            for audio, sampling_rate in multi_modal_data
        ]
        batch_data = audio_feature_extractor(resampled_audios,
                                             sampling_rate=target_sr,
                                             return_attention_mask=True,
                                             padding="max_length",
                                             return_tensors="pt").data
        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
    except Exception:
        logger.error("Failed to process audio (%s)", multi_modal_data)
        raise

    return MultiModalInputs(batch_data)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
input_mapper_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_qwen2_audio_audio_tokens)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
             config: Qwen2AudioConfig,
             multimodal_config: MultiModalConfig,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None) -> None:
    """Build the audio tower, projector, language model and output head.

    Args:
        config: Model config with `audio_config` and `text_config` parts.
        multimodal_config: Stored on the instance as-is.
        cache_config: Forwarded to the language model.
        quant_config: Forwarded to the language model and the LM head.
    """
    super().__init__()

    audio_config = config.audio_config
    text_config = config.text_config

    self.config = config
    self.multimodal_config = multimodal_config

    self.audio_tower = Qwen2AudioEncoder(audio_config)
    self.multi_modal_projector = Qwen2AudioMultiModalProjector(
        audio_hidden_size=audio_config.d_model,
        text_hidden_size=text_config.hidden_size,
    )

    self.quant_config = quant_config

    self.language_model = Qwen2Model(text_config, cache_config, quant_config)
    self.unpadded_vocab_size = text_config.vocab_size

    if not text_config.tie_word_embeddings:
        self.lm_head = ParallelLMHead(text_config.vocab_size,
                                      text_config.hidden_size,
                                      quant_config=quant_config)
    else:
        # Tied embeddings: the input embedding doubles as the output head.
        self.lm_head = self.language_model.embed_tokens

    self.logits_processor = LogitsProcessor(
        self.unpadded_vocab_size,
        text_config.vocab_size,
        getattr(config, "logit_scale", 1.0),
    )
    self.sampler = Sampler()

    # Reuse the language model's factory for empty intermediate tensors.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
def _parse_and_validate_audio_input(
        self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
    """Extract and validate the audio tensors from the forward kwargs.

    Reads ``input_features`` and ``feature_attention_mask`` from ``kwargs``.

    Returns:
        ``None`` when no ``input_features`` were passed, otherwise a
        ``Qwen2AudioInputs`` holding both tensors concatenated along dim 0.

    Raises:
        ValueError: If either value is neither a tensor nor a list (this
            includes a missing ``feature_attention_mask``).
    """
    input_features = kwargs.pop('input_features', None)
    feature_attention_mask = kwargs.pop('feature_attention_mask', None)
    if input_features is None:
        return None

    # The helper rejects anything that is not a tensor or list and always
    # returns the result of torch.concat, so no further type check is needed.
    input_features = self._validate_and_reshape_mm_tensor(
        input_features, 'input_features')
    feature_attention_mask = self._validate_and_reshape_mm_tensor(
        feature_attention_mask, 'feature_attention_mask')

    return Qwen2AudioInputs(input_features=input_features,
                            feature_attention_mask=feature_attention_mask)
def _process_audio_input(self,
                         audio_input: Qwen2AudioInputs) -> torch.Tensor:
    """Encode audio features and return the projected, unpadded embeddings.

    Runs the audio tower with an additive attention mask that blocks padded
    positions, applies the multimodal projector, and keeps only positions
    below each clip's output length.

    Returns:
        A 2-D tensor whose last dimension is the projector output size.
    """
    features = audio_input["input_features"]
    feature_mask = audio_input["feature_attention_mask"]

    # Per-clip lengths after the tower's internal downsampling stages.
    feat_lengths, output_lengths = (
        self.audio_tower._get_feat_extract_output_lengths(
            feature_mask.sum(-1)))

    batch_size, _, max_mel_seq_len = features.shape
    max_seq_len = (max_mel_seq_len - 2) // 2 + 1

    # Create a sequence tensor of shape (batch_size, max_seq_len)
    positions = torch.arange(0,
                             max_seq_len,
                             dtype=feat_lengths.dtype,
                             device=feat_lengths.device)
    seq_range = positions.unsqueeze(0).expand(batch_size, max_seq_len)
    lengths_expand = feat_lengths.unsqueeze(-1).expand(
        batch_size, max_seq_len)

    # True where a position lies at or beyond the clip's feature length.
    is_padding = seq_range >= lengths_expand
    is_padding_4d = is_padding.view(batch_size, 1, 1, max_seq_len).expand(
        batch_size, 1, max_seq_len, max_seq_len)

    # Convert to the tower's dtype/device, then set padded entries to -inf.
    conv_weight = self.audio_tower.conv1.weight
    audio_attention_mask = is_padding_4d.to(dtype=conv_weight.dtype,
                                            device=conv_weight.device)
    audio_attention_mask[is_padding_4d] = float("-inf")

    tower_outputs = self.audio_tower(features,
                                     attention_mask=audio_attention_mask)
    audio_features = self.multi_modal_projector(
        tower_outputs.last_hidden_state)

    num_audios, max_audio_tokens, embed_dim = audio_features.shape
    token_positions = torch.arange(max_audio_tokens).expand(
        num_audios, max_audio_tokens).to(output_lengths.device)
    is_valid_token = token_positions < output_lengths.unsqueeze(1)

    # Drop padded positions and flatten the clips into one row list.
    return audio_features[is_valid_token].view(-1, embed_dim)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the language model, splicing in audio embeddings when present.

    When ``intermediate_tensors`` is given, neither token ids nor input
    embeddings are forwarded. Otherwise the keyword arguments are parsed
    for audio input; if any is found, the token embeddings are computed,
    the positions equal to ``self.config.audio_token_index`` are
    overwritten with the audio embeddings, and the embeddings are passed
    to the language model instead of the token ids.

    Returns:
        Whatever ``self.language_model`` returns for this step.
    """
    inputs_embeds = None

    if intermediate_tensors is not None:
        # Token ids are not forwarded when intermediate tensors are given.
        input_ids = None
    else:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is not None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)
            audio_embeds = self._process_audio_input(audio_input)

            # merge llm embeddings and audio features
            is_audio_token = input_ids == self.config.audio_token_index
            inputs_embeds[is_audio_token, :] = audio_embeds

            # The embeddings now carry everything; drop the ids.
            input_ids = None

    return self.language_model(
        input_ids=input_ids,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Return logits for ``hidden_states``.

    Delegates to ``self.logits_processor``, passing ``self.lm_head``, the
    hidden states and the sampling metadata.
    """
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Return the result of ``self.sampler`` for the given logits."""
    return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint tensors into this module's parameters.

    Each ``(name, tensor)`` pair is renamed through
    ``_KEYS_TO_MODIFY_MAPPING`` and then either routed into a fused
    parameter (``qkv_proj`` / ``gate_up_proj``) with a shard id, or loaded
    directly under its own name. Rotary inverse-frequency entries are
    skipped, as is ``lm_head.weight`` when word embeddings are tied.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    fused_shards = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    known_params = dict(self.named_parameters(remove_duplicate=False))

    for name, tensor in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.config.text_config.tie_word_embeddings
                and "lm_head.weight" in name):
            continue

        # str.replace is a no-op when the key is absent, so no membership
        # test is needed before applying each rename.
        for old_key, new_key in _KEYS_TO_MODIFY_MAPPING.items():
            name = name.replace(old_key, new_key)

        loaded_as_shard = False
        for fused_name, shard_name, shard_id in fused_shards:
            # Names containing 'audio' are never fused.
            if 'audio' in name or shard_name not in name:
                continue
            name = name.replace(shard_name, fused_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in known_params:
                continue
            target = known_params[name]
            target.weight_loader(target, tensor, shard_id)
            loaded_as_shard = True
            break

        if loaded_as_shard:
            continue

        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in known_params:
            continue
        # Remapping the name of FP8 kv-scale.
        name = maybe_remap_kv_scale_name(name, known_params)
        if name is None:
            continue
        target = known_params[name]
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)
...@@ -119,5 +119,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -119,5 +119,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights) loader.load_weights(weights)
...@@ -61,6 +61,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, ...@@ -61,6 +61,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs) MultiModalInputs)
from vllm.multimodal.base import MultiModalData from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
...@@ -549,6 +550,9 @@ def mm_input_mapper_for_qwen2_vl( ...@@ -549,6 +550,9 @@ def mm_input_mapper_for_qwen2_vl(
ctx: InputContext, ctx: InputContext,
data: MultiModalData[object], data: MultiModalData[object],
data_type_key: str, data_type_key: str,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
"""Input mapper for Qwen2-VL.""" """Input mapper for Qwen2-VL."""
if data_type_key == "image" and isinstance(data, dict): if data_type_key == "image" and isinstance(data, dict):
...@@ -557,8 +561,19 @@ def mm_input_mapper_for_qwen2_vl( ...@@ -557,8 +561,19 @@ def mm_input_mapper_for_qwen2_vl(
"image_grid_thw": data.get("image_grid_thw"), "image_grid_thw": data.get("image_grid_thw"),
}) })
model_config = ctx.model_config model_config = ctx.model_config
# Handle mm processor kwargs; we pass these at creation time
# because preprocess() in transformers doesn't expose them
mm_processor_kwargs = {}
if min_pixels:
mm_processor_kwargs["min_pixels"] = min_pixels
if max_pixels:
mm_processor_kwargs["max_pixels"] = max_pixels
image_processor = cached_get_image_processor( image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code) model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
if image_processor is None: if image_processor is None:
raise RuntimeError("No HuggingFace processor is available " raise RuntimeError("No HuggingFace processor is available "
"to process the image object") "to process the image object")
...@@ -631,25 +646,36 @@ def _get_max_image_info( ...@@ -631,25 +646,36 @@ def _get_max_image_info(
image_processor, image_processor,
data_type_key: str = "image", data_type_key: str = "image",
mm_count: int = 1, mm_count: int = 1,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
): ):
# Limit min / max pixels unless they're explicitly provided
if min_pixels is None:
min_pixels = max(image_processor.min_pixels, 28 * 28)
if max_pixels is None:
max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)
return _get_vision_info( return _get_vision_info(
image_processor, image_processor,
height=9999999, height=9999999,
width=9999999, width=9999999,
min_pixels=min_pixels,
# Limit min / max pixels. max_pixels=max_pixels,
min_pixels=max(image_processor.min_pixels, 28 * 28),
max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
data_type_key=data_type_key, data_type_key=data_type_key,
mm_count=mm_count, mm_count=mm_count,
) )
def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
data_type_key: str,
*,
min_pixels=None,
max_pixels=None) -> int:
image_processor = cached_get_image_processor(ctx.model_config.model) image_processor = cached_get_image_processor(ctx.model_config.model)
max_resized_height, max_resized_width, max_llm_image_tokens = \ max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key=data_type_key, _get_max_image_info(image_processor, data_type_key=data_type_key,
mm_count=1) mm_count=1, min_pixels=min_pixels,
max_pixels=max_pixels)
return max_llm_image_tokens return max_llm_image_tokens
...@@ -660,14 +686,20 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, ...@@ -660,14 +686,20 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
def dummy_data_for_qwen2_vl( def dummy_data_for_qwen2_vl(
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
image_processor = cached_get_image_processor(ctx.model_config.model) image_processor = cached_get_image_processor(ctx.model_config.model)
num_images = mm_counts["image"] num_images = mm_counts["image"]
max_resized_height, max_resized_width, max_llm_image_tokens = \ max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key="image", _get_max_image_info(image_processor, data_type_key="image",
mm_count=num_images) mm_count=num_images, min_pixels=min_pixels,
max_pixels=max_pixels)
if seq_len - max_llm_image_tokens - 2 < 0: if seq_len - max_llm_image_tokens - 2 < 0:
raise RuntimeError( raise RuntimeError(
f"Qwen2-VL cannot process {num_images} images in a prompt, " f"Qwen2-VL cannot process {num_images} images in a prompt, "
...@@ -678,10 +710,11 @@ def dummy_data_for_qwen2_vl( ...@@ -678,10 +710,11 @@ def dummy_data_for_qwen2_vl(
num_videos = mm_counts["video"] num_videos = mm_counts["video"]
max_resized_height, max_resized_width, max_llm_video_tokens = \ max_resized_height, max_resized_width, max_llm_video_tokens = \
_get_max_image_info(image_processor, data_type_key="video", _get_max_image_info(image_processor, data_type_key="video",
mm_count=num_videos) mm_count=num_videos, min_pixels=min_pixels,
max_pixels=max_pixels)
if seq_len - max_llm_video_tokens - 2 < 0: if seq_len - max_llm_video_tokens - 2 < 0:
raise RuntimeError( raise RuntimeError(
f"Qwen2-VL cannot process {num_images} videos in a prompt, " f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
"please increase max_model_len or reduce video limit by " "please increase max_model_len or reduce video limit by "
"--limit-mm-per-prompt.") "--limit-mm-per-prompt.")
...@@ -706,6 +739,8 @@ def _get_llm_num_vision_tokens( ...@@ -706,6 +739,8 @@ def _get_llm_num_vision_tokens(
mm_inputs: list, mm_inputs: list,
data_type_key: str, data_type_key: str,
image_processor, image_processor,
min_pixels: int,
max_pixels: int,
): ):
"""Get number of vision tokens of multimodal inputs. """Get number of vision tokens of multimodal inputs.
...@@ -715,12 +750,13 @@ def _get_llm_num_vision_tokens( ...@@ -715,12 +750,13 @@ def _get_llm_num_vision_tokens(
image = to_numpy_array(mm_inputs[0]) image = to_numpy_array(mm_inputs[0])
input_data_format = infer_channel_dimension_format(image) input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, channel_dim=input_data_format) height, width = get_image_size(image, channel_dim=input_data_format)
_, _, llm_num_vision_tokens = _get_vision_info( _, _, llm_num_vision_tokens = _get_vision_info(
image_processor, image_processor,
height=height, height=height,
width=width, width=width,
min_pixels=image_processor.min_pixels, min_pixels=min_pixels,
max_pixels=image_processor.max_pixels, max_pixels=max_pixels,
do_resize=image_processor.do_resize, do_resize=image_processor.do_resize,
data_type_key=data_type_key, data_type_key=data_type_key,
mm_count=len(mm_inputs), mm_count=len(mm_inputs),
...@@ -730,7 +766,8 @@ def _get_llm_num_vision_tokens( ...@@ -730,7 +766,8 @@ def _get_llm_num_vision_tokens(
def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
data_type_key: str, image_processor: Any, data_type_key: str, image_processor: Any,
prompt_token_ids: List[int]) -> List[int]: prompt_token_ids: List[int], min_pixels: Optional[int],
max_pixels: Optional[int]) -> List[int]:
""" """
Expand pad tokens for multi-modal inputs (e.g., images or videos). Expand pad tokens for multi-modal inputs (e.g., images or videos).
...@@ -741,6 +778,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, ...@@ -741,6 +778,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
data_type_key (str): The type of the multi-modal input. data_type_key (str): The type of the multi-modal input.
image_processor (Any): The image processor used to process the inputs. image_processor (Any): The image processor used to process the inputs.
prompt_token_ids (List[int]): The list of token IDs in the prompt. prompt_token_ids (List[int]): The list of token IDs in the prompt.
min_pixels (int): min pixels to be used for img processing
max_pixels (int): max pixels to be used for img processing
Returns: Returns:
List[int]: The list of token IDs for the multi-modal inputs. List[int]: The list of token IDs for the multi-modal inputs.
...@@ -757,6 +796,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, ...@@ -757,6 +796,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
[data] if data_type_key == "image" else data, [data] if data_type_key == "image" else data,
data_type_key=data_type_key, data_type_key=data_type_key,
image_processor=image_processor, image_processor=image_processor,
min_pixels=min_pixels,
max_pixels=max_pixels,
) )
if cnt == 0: if cnt == 0:
end_idx = indices[cnt] end_idx = indices[cnt]
...@@ -773,8 +814,11 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, ...@@ -773,8 +814,11 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
def input_processor_for_qwen2_vl( def input_processor_for_qwen2_vl(
ctx: InputContext, ctx: InputContext,
inputs: DecoderOnlyInputs, inputs: DecoderOnlyInputs,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data", None) multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None: if multi_modal_data is None:
return inputs return inputs
...@@ -783,6 +827,11 @@ def input_processor_for_qwen2_vl( ...@@ -783,6 +827,11 @@ def input_processor_for_qwen2_vl(
processor = cached_get_processor(ctx.model_config.model) processor = cached_get_processor(ctx.model_config.model)
image_processor = processor.image_processor image_processor = processor.image_processor
# Apply processor kwarg overrides for image processor options
min_pixels = min_pixels if min_pixels else image_processor.min_pixels
max_pixels = max_pixels if max_pixels else image_processor.max_pixels
model_config = ctx.model_config
hf_config = ctx.get_hf_config(Qwen2VLConfig) hf_config = ctx.get_hf_config(Qwen2VLConfig)
# To avoid redundant processing of vision objects (resize, rescale, etc.), # To avoid redundant processing of vision objects (resize, rescale, etc.),
...@@ -798,14 +847,11 @@ def input_processor_for_qwen2_vl( ...@@ -798,14 +847,11 @@ def input_processor_for_qwen2_vl(
# return_tensors="pt") # return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist() # prompt_token_ids = inputs["input_ids"][0].tolist()
prompt_token_ids = inputs.get("prompt_token_ids", None) tokenizer = cached_get_tokenizer(
if prompt_token_ids is None: model_config.tokenizer,
prompt = inputs["prompt"] trust_remote_code=model_config.trust_remote_code)
prompt_token_ids = processor.tokenizer(
prompt, prompt_token_ids = inputs["prompt_token_ids"]
padding=True,
return_tensors=None,
)["input_ids"]
# Expand image pad tokens. # Expand image pad tokens.
...@@ -830,20 +876,30 @@ def input_processor_for_qwen2_vl( ...@@ -830,20 +876,30 @@ def input_processor_for_qwen2_vl(
else: else:
prompt_token_ids = _expand_pad_tokens(image_inputs, prompt_token_ids = _expand_pad_tokens(image_inputs,
hf_config.image_token_id, hf_config.image_token_id,
make_batched_images, "image", make_batched_images,
"image",
image_processor, image_processor,
prompt_token_ids) prompt_token_ids,
min_pixels=min_pixels,
max_pixels=max_pixels)
if video_inputs is not None: if video_inputs is not None:
prompt_token_ids = _expand_pad_tokens(video_inputs, prompt_token_ids = _expand_pad_tokens(video_inputs,
hf_config.video_token_id, hf_config.video_token_id,
make_batched_videos, "video", make_batched_videos,
"video",
image_processor, image_processor,
prompt_token_ids) prompt_token_ids,
min_pixels=min_pixels,
max_pixels=max_pixels)
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
return token_inputs( return token_inputs(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt=inputs["prompt"], prompt=prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
......
...@@ -26,8 +26,10 @@ _TEXT_GENERATION_MODELS = { ...@@ -26,8 +26,10 @@ _TEXT_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b # baichuan-7b, upper case 'C' in the class name
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
# ChatGLMModel supports multimodal # ChatGLMModel supports multimodal
"CohereForCausalLM": ("commandr", "CohereForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
...@@ -85,6 +87,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -85,6 +87,7 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder] # [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"), "BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
} }
_EMBEDDING_MODELS = { _EMBEDDING_MODELS = {
...@@ -118,6 +121,7 @@ _MULTIMODAL_MODELS = { ...@@ -118,6 +121,7 @@ _MULTIMODAL_MODELS = {
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder] # [Encoder-decoder]
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
......
This diff is collapsed.
...@@ -117,6 +117,9 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): ...@@ -117,6 +117,9 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
if not isinstance(data, list): if not isinstance(data, list):
data = [data] data = [data]
if len(data) == 0:
return MultiModalInputs()
# If the audio inputs are embeddings, no need for preprocessing # If the audio inputs are embeddings, no need for preprocessing
if is_list_of(data, torch.Tensor, check="all"): if is_list_of(data, torch.Tensor, check="all"):
return MultiModalInputs({"audio_embeds": data}) return MultiModalInputs({"audio_embeds": data})
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment