"tests/vscode:/vscode.git/clone" did not exist on "c37c0af990ed1f3623448b82903c1ae52e84cc05"
Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
# ruff: noqa: SIM117 # ruff: noqa: SIM117
import collections import collections
import copy import copy
import dataclasses
import fnmatch import fnmatch
import glob import glob
import json import json
...@@ -8,7 +9,8 @@ import math ...@@ -8,7 +9,8 @@ import math
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple, Type from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
Type, cast)
import gguf import gguf
import huggingface_hub import huggingface_hub
...@@ -22,6 +24,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME ...@@ -22,6 +24,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig, LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -95,10 +99,10 @@ def _get_quantization_config( ...@@ -95,10 +99,10 @@ def _get_quantization_config(
"""Get the quantization config.""" """Get the quantization config."""
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
capability = current_platform.get_device_capability() # type: ignore capability_tuple = current_platform.get_device_capability()
if capability is not None: if capability_tuple is not None:
capability = capability[0] * 10 + capability[1] capability = capability_tuple.to_int()
if capability < quant_config.get_min_capability(): if capability < quant_config.get_min_capability():
raise ValueError( raise ValueError(
f"The quantization method {model_config.quantization} " f"The quantization method {model_config.quantization} "
...@@ -208,6 +212,22 @@ class BaseModelLoader(ABC): ...@@ -208,6 +212,22 @@ class BaseModelLoader(ABC):
class DefaultModelLoader(BaseModelLoader): class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk.""" """Model loader that can load different file types from disk."""
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
revision: Optional[str]
"""The optional model revision."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
if load_config.model_loader_extra_config: if load_config.model_loader_extra_config:
...@@ -314,17 +334,16 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -314,17 +334,16 @@ class DefaultModelLoader(BaseModelLoader):
return hf_folder, hf_weights_files, use_safetensors return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator( def _get_weights_iterator(
self, model_name_or_path: str, revision: Optional[str], self, source: "Source"
fall_back_to_pt: bool
) -> Generator[Tuple[str, torch.Tensor], None, None]: ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format.""" """Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision, fall_back_to_pt) source.model_or_path, source.revision, source.fall_back_to_pt)
if self.load_config.load_format == LoadFormat.NPCACHE: if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints # Currently np_cache only support *.bin checkpoints
assert use_safetensors is False assert use_safetensors is False
weights_iterator = np_cache_weights_iterator( weights_iterator = np_cache_weights_iterator(
model_name_or_path, self.load_config.download_dir, hf_folder, source.model_or_path, self.load_config.download_dir, hf_folder,
hf_weights_files) hf_weights_files)
elif use_safetensors: elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files) weights_iterator = safetensors_weights_iterator(hf_weights_files)
...@@ -342,7 +361,29 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -342,7 +361,29 @@ class DefaultModelLoader(BaseModelLoader):
xm.mark_step() xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator) weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator
# Apply the prefix.
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
def _get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
True))
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(Iterable[DefaultModelLoader.Source],
getattr(model, "secondary_weights", ()))
for source in secondary_weights:
yield from self._get_weights_iterator(source)
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, self._prepare_weights(model_config.model,
...@@ -361,13 +402,8 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -361,13 +402,8 @@ class DefaultModelLoader(BaseModelLoader):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, cache_config, lora_config, cache_config,
scheduler_config) scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model, model.load_weights(self._get_all_weights(model_config, model))
model_config.revision,
fall_back_to_pt=getattr(
model,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
...@@ -692,6 +728,8 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -692,6 +728,8 @@ class ShardedStateLoader(BaseModelLoader):
class BitsAndBytesModelLoader(BaseModelLoader): class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization.""" """Model loader to load model weights with BitAndBytes quantization."""
# TODO: these module names are for Llama only,
# change so that it works with other models as well
default_target_modules = [ default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj" "o_proj"
...@@ -816,12 +854,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -816,12 +854,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed # only load the bitsandbytes module when needed
try: try:
import bitsandbytes import bitsandbytes
if bitsandbytes.__version__ < "0.42.0": if bitsandbytes.__version__ < "0.44.0":
raise ImportError("bitsandbytes version is wrong. Please " raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.") "install bitsandbytes>=0.44.0.")
except ImportError as err: except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.42.0 via " raise ImportError("Please install bitsandbytes>=0.44.0 via "
"`pip install bitsandbytes>=0.42.0` to use " "`pip install bitsandbytes>=0.44.0` to use "
"bitsandbytes quantizer.") from err "bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights( hf_weights_files, use_safetensors = self._prepare_weights(
...@@ -914,13 +952,44 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -914,13 +952,44 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _unquantized_generator(self, hf_weights_files, use_safetensors, def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator: quant_state_dict) -> Generator:
from bitsandbytes.functional import quantize_4bit from bitsandbytes.functional import quantize_4bit
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for weight_name, weight_tensor in self._hf_weight_iter( for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors): hf_weights_files, use_safetensors):
if any(target_module in weight_name if any(target_module in weight_name
for target_module in self.target_modules): for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight") weight_name = weight_name.replace(".weight", ".qweight")
# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[...,
start_index:end_index]
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index,
...]
# bitsandbytes requires data in GPU # bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.cuda()
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()
with set_default_torch_dtype(torch.float32): with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit( processed_weight, quant_state = quantize_4bit(
loaded_weight, loaded_weight,
...@@ -961,6 +1030,13 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -961,6 +1030,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"BitsAndBytes loader does not support {quant_method} " f"BitsAndBytes loader does not support {quant_method} "
"quantization") "quantization")
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with TP is not supported."
"Please try with PP.")
load_8bit = False load_8bit = False
if pre_quant: if pre_quant:
load_8bit = quant_config.get('load_in_8bit', False) load_8bit = quant_config.get('load_in_8bit', False)
......
...@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: ...@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"inferred as vLLM models, so setting vllm_tensorized=True is " "inferred as vLLM models, so setting vllm_tensorized=True is "
"only necessary for models serialized prior to this change.") "only necessary for models serialized prior to this change.")
return True return True
if (".vllm_tensorized_marker" in deserializer): return ".vllm_tensorized_marker" in deserializer
return True
return False
def serialize_vllm_model( def serialize_vllm_model(
......
...@@ -37,13 +37,7 @@ def get_model_architecture( ...@@ -37,13 +37,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"] mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported and model_config.quantization not in mixtral_supported
......
...@@ -43,13 +43,15 @@ _GENERATION_MODELS = { ...@@ -43,13 +43,15 @@ _GENERATION_MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"), "MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
...@@ -59,6 +61,7 @@ _GENERATION_MODELS = { ...@@ -59,6 +61,7 @@ _GENERATION_MODELS = {
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
...@@ -80,12 +83,14 @@ _MULTIMODAL_MODELS = { ...@@ -80,12 +83,14 @@ _MULTIMODAL_MODELS = {
("chameleon", "ChameleonForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": "LlavaForConditionalGeneration": ("llava",
("llava", "LlavaForConditionalGeneration"), "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration": ("llava_next",
"LlavaNextForConditionalGeneration"), "LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration": "LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"LlavaOnevisionForConditionalGeneration":
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"), "MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"), "PaliGemmaForConditionalGeneration"),
...@@ -96,6 +101,8 @@ _MULTIMODAL_MODELS = { ...@@ -96,6 +101,8 @@ _MULTIMODAL_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"), "Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
} }
_CONDITIONAL_GENERATION_MODELS = { _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"), "BartModel": ("bart", "BartForConditionalGeneration"),
......
...@@ -848,11 +848,13 @@ class BartForConditionalGeneration(nn.Module): ...@@ -848,11 +848,13 @@ class BartForConditionalGeneration(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
Args: Args:
......
"""Minimal implementation of BlipVisionModel intended to be only used """Minimal implementation of BlipVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from array import array from typing import Iterable, Optional, Tuple, Union
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -17,9 +16,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -17,9 +16,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import SequenceData
try: try:
from xformers import ops as xops from xformers import ops as xops
...@@ -53,6 +53,7 @@ def get_max_blip_image_tokens( ...@@ -53,6 +53,7 @@ def get_max_blip_image_tokens(
def dummy_seq_data_for_blip( def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int, seq_len: int,
num_images: int,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
...@@ -62,11 +63,10 @@ def dummy_seq_data_for_blip( ...@@ -62,11 +63,10 @@ def dummy_seq_data_for_blip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, return SequenceData.from_token_counts(
[image_token_id]) * image_feature_size (image_token_id, image_feature_size * num_images),
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, (0, seq_len - image_feature_size * num_images),
[0]) * (seq_len - image_feature_size) )
return SequenceData(token_ids)
def dummy_image_for_blip( def dummy_image_for_blip(
...@@ -343,6 +343,10 @@ class BlipVisionModel(nn.Module): ...@@ -343,6 +343,10 @@ class BlipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.config = config self.config = config
self.embeddings = BlipVisionEmbeddings(config) self.embeddings = BlipVisionEmbeddings(config)
...@@ -351,11 +355,61 @@ class BlipVisionModel(nn.Module): ...@@ -351,11 +355,61 @@ class BlipVisionModel(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
) )
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values) hidden_states = self.embeddings(pixel_values)
hidden_states = self.encoder(inputs_embeds=hidden_states) hidden_states = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return hidden_states
return self.post_layernorm(hidden_states) return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is not needed in BlipVisionModel
if (name.startswith("post_layernorm")
and self.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("encoder.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -11,25 +10,18 @@ from vllm.attention import AttentionMetadata ...@@ -11,25 +10,18 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import IntermediateTensors, SequenceData
SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo # defined on the HuggingFace repo
...@@ -429,11 +421,10 @@ def dummy_seq_data_for_blip2( ...@@ -429,11 +421,10 @@ def dummy_seq_data_for_blip2(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, return SequenceData.from_token_counts(
[image_token_id]) * image_feature_size * num_images (image_token_id, image_feature_size * num_images),
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, (0, seq_len - image_feature_size * num_images),
[0]) * (seq_len - image_feature_size * num_images) )
return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int, def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
...@@ -494,9 +485,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -494,9 +485,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__() super().__init__()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -517,17 +505,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -517,17 +505,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
bias=True, bias=True,
) )
self.quant_config = quant_config self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.language_model = OPTModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
self.sampler = Sampler()
def get_lm_head(self):
return self.language_model.decoder.embed_tokens
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
...@@ -656,7 +635,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -656,7 +635,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
...@@ -666,11 +646,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -666,11 +646,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -679,56 +659,46 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -679,56 +659,46 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
return logits
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now. # prepare weight iterators for components
stacked_params_mapping = [ weights_group = group_weights_with_prefix(weights)
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), # load vision encoder
("qkv_proj", "k_proj", "k"), self.vision_model.load_weights(weights_group["vision_model"])
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0), # load query tokens
("gate_up_proj", "up_proj", 1), for name, loaded_weight in weights_group["query_tokens"]:
] assert name == ""
params_dict = dict(self.named_parameters()) param = self.query_tokens
weight_loader = getattr(param, "weight_loader",
for name, loaded_weight in weights: default_weight_loader)
if "lm_head.weight" in name: weight_loader(param, loaded_weight)
continue
if "rotary_emb.inv_freq" in name: # load qformer
continue qformer_params_dict = dict(self.qformer.named_parameters())
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): for name, loaded_weight in weights_group["qformer"]:
if key_to_modify in name: param = qformer_params_dict[name]
name = name.replace(key_to_modify, new_key) weight_loader = getattr(param, "weight_loader",
use_default_weight_loading = False default_weight_loader)
if "vision" in name: weight_loader(param, loaded_weight)
if self.vision_model is not None:
# BlipVisionModel does not need sharding # load mlp projector
use_default_weight_loading = True mlp_params_dict = dict(self.language_projection.named_parameters())
else: for name, loaded_weight in weights_group["language_projection"]:
for (param_name, weight_name, param = mlp_params_dict[name]
shard_id) in stacked_params_mapping: weight_loader = getattr(param, "weight_loader",
if weight_name not in name: default_weight_loader)
continue weight_loader(param, loaded_weight)
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader # load llm backbone
weight_loader(param, loaded_weight, shard_id) self.language_model.load_weights(weights_group["language_model"])
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
from array import array
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict) Tuple, TypedDict)
...@@ -13,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -13,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -32,14 +30,11 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -32,14 +30,11 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import IntermediateTensors, SequenceData
SequenceData)
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
logger = init_logger(__name__)
# These configs are not part of the model config but the preprocessor # These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now. # and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
...@@ -72,11 +67,10 @@ def dummy_seq_data_for_chameleon( ...@@ -72,11 +67,10 @@ def dummy_seq_data_for_chameleon(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, return SequenceData.from_token_counts(
[image_token_id]) * image_feature_size * num_images (image_token_id, image_feature_size * num_images),
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, (0, seq_len - image_feature_size * num_images),
[0]) * (seq_len - image_feature_size * num_images) )
return SequenceData(token_ids)
def dummy_image_for_chameleon( def dummy_image_for_chameleon(
......
"""Minimal implementation of CLIPVisionModel intended to be only used """Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from array import array
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import SequenceData
try: try:
from xformers import ops as xops from xformers import ops as xops
...@@ -62,11 +62,10 @@ def dummy_seq_data_for_clip( ...@@ -62,11 +62,10 @@ def dummy_seq_data_for_clip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, return SequenceData.from_token_counts(
[image_token_id]) * image_feature_size * num_images (image_token_id, image_feature_size * num_images),
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, (0, seq_len - image_feature_size * num_images),
[0]) * (seq_len - image_feature_size * num_images) )
return SequenceData(token_ids)
def dummy_image_for_clip( def dummy_image_for_clip(
...@@ -86,6 +85,24 @@ def dummy_image_for_clip( ...@@ -86,6 +85,24 @@ def dummy_image_for_clip(
return {"image": image if num_images == 1 else [image] * num_images} return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_clip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
mm_data = {"video": mm_data_per_video}
return mm_data
def input_processor_for_clip( def input_processor_for_clip(
model_config: ModelConfig, model_config: ModelConfig,
hf_config: CLIPVisionConfig, hf_config: CLIPVisionConfig,
...@@ -393,6 +410,7 @@ class CLIPVisionModel(nn.Module): ...@@ -393,6 +410,7 @@ class CLIPVisionModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
...@@ -402,10 +420,6 @@ class CLIPVisionModel(nn.Module): ...@@ -402,10 +420,6 @@ class CLIPVisionModel(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override) num_hidden_layers_override=num_hidden_layers_override)
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values) return self.vision_model(pixel_values)
...@@ -427,12 +441,12 @@ class CLIPVisionModel(nn.Module): ...@@ -427,12 +441,12 @@ class CLIPVisionModel(nn.Module):
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel # post_layernorm is not needed in CLIPVisionModel
if ("vision_model.post_layernorm" in name if (name.startswith("vision_model.post_layernorm")
and not self._require_post_layernorm): and self.vision_model.post_layernorm is None):
continue continue
# omit layers when num_hidden_layers_override is set # omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name: if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3]) layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue
......
...@@ -7,9 +7,8 @@ import torch.nn as nn ...@@ -7,9 +7,8 @@ import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size)
tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
...@@ -22,7 +21,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -22,7 +21,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
...@@ -54,13 +52,7 @@ class DbrxRouter(nn.Module): ...@@ -54,13 +52,7 @@ class DbrxRouter(nn.Module):
return router_logits return router_logits
class DbrxExperts(nn.Module): class DbrxExperts(FusedMoE):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__( def __init__(
self, self,
...@@ -68,49 +60,24 @@ class DbrxExperts(nn.Module): ...@@ -68,49 +60,24 @@ class DbrxExperts(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
): ):
super().__init__() super().__init__(
num_experts=config.ffn_config.moe_num_experts,
top_k=config.ffn_config.moe_top_k,
hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
self.d_model = config.d_model self.d_model = config.d_model
self.intermediate_size = (config.ffn_config.ffn_hidden_size // self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
self.tp_size) self.tp_size)
if params_dtype is None: # Define custom weight loader for dbrx model
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.d_model,
device="cuda",
dtype=self.params_dtype,
))
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.d_model,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
))
set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str): weight_name: str):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -140,26 +107,40 @@ class DbrxExperts(nn.Module): ...@@ -140,26 +107,40 @@ class DbrxExperts(nn.Module):
).transpose(1, 2) ).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard] param_data[:] = loaded_weight[:, :, shard]
class DbrxMoE(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.d_model = config.d_model
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.experts = DbrxExperts(config=config,
quant_config=quant_config,
params_dtype=self.params_dtype)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model) hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states) router_logits = self.router(hidden_states)
final_hidden_states = fused_moe( final_hidden_states = self.experts(hidden_states, router_logits)
hidden_states, return final_hidden_states.view(orig_shape)
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class DbrxAttention(nn.Module): class DbrxAttention(nn.Module):
...@@ -288,7 +269,7 @@ class DbrxBlock(nn.Module): ...@@ -288,7 +269,7 @@ class DbrxBlock(nn.Module):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config) quant_config)
self.ffn = DbrxExperts(config, quant_config) self.ffn = DbrxMoE(config, quant_config)
def forward( def forward(
self, self,
...@@ -409,9 +390,10 @@ class DbrxForCausalLM(nn.Module): ...@@ -409,9 +390,10 @@ class DbrxForCausalLM(nn.Module):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [( expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s", "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"experts.mlp.{weight_name}", f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]] ) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -44,7 +44,7 @@ class EAGLE(nn.Module): ...@@ -44,7 +44,7 @@ class EAGLE(nn.Module):
self.model = model_cls(self.config.model, *args, **kwargs) self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2, self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size, config.model.hidden_size,
bias=False) bias=getattr(self.config, "bias", False))
self.orig_vocab_size = config.vocab_size self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size self.truncated_vocab_size = config.truncated_vocab_size
...@@ -136,10 +136,18 @@ class EAGLE(nn.Module): ...@@ -136,10 +136,18 @@ class EAGLE(nn.Module):
if self.config.truncated_vocab_size < self.config.vocab_size: if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight, self.token_map = nn.Parameter(loaded_weight,
requires_grad=False) requires_grad=False)
elif name.startswith("fc."): elif name.startswith("fc.weight"):
weight_loader = getattr(self.fc.weight, "weight_loader", weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(self.fc.weight, loaded_weight) weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("fc.bias"):
if self.fc.bias is not None:
weight_loader = getattr(self.fc.bias, "weight_loader",
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
elif name.startswith("model.lm_head.") or name.startswith( elif name.startswith("model.lm_head.") or name.startswith(
"model.model."): "model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight model_weights[name.split("model.", 1)[-1]] = loaded_weight
......
...@@ -28,7 +28,6 @@ from transformers import FuyuConfig, FuyuImageProcessor ...@@ -28,7 +28,6 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -45,8 +44,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, ...@@ -45,8 +44,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019 _NEWLINE_TOKEN_ID = 71019
...@@ -232,7 +229,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -232,7 +229,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.text_config.vocab_size
self.image_token_id = _IMAGE_TOKEN_ID self.image_token_id = _IMAGE_TOKEN_ID
self.image_feature_size = config.patch_size**2 * config.num_channels self.image_feature_size = config.patch_size**2 * config.num_channels
......
...@@ -428,7 +428,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA): ...@@ -428,7 +428,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
logits /= self.config.logits_scaling if logits is not None:
logits /= self.config.logits_scaling
return logits return logits
def sample( def sample(
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Copyright (c) 2023 OpenGVLab # Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import itertools
import re import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -20,7 +19,7 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -20,7 +19,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -33,8 +32,8 @@ from vllm.utils import is_list_of ...@@ -33,8 +32,8 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -231,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -231,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(
trust_remote_code=True) model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = llm_inputs.get("prompt") prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = llm_inputs["prompt_token_ids"]
...@@ -279,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): ...@@ -279,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
use_thumbnail=use_thumbnail) for img in data use_thumbnail=use_thumbnail) for img in data
] ]
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(
trust_remote_code=True) model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(IMG_CONTEXT, image_token_id = tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False, add_special_tokens=False,
return_tensors="pt")[0] return_tensors="pt")[0]
...@@ -299,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, ...@@ -299,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config() hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(
trust_remote_code=True) model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
vision_config, vision_config,
...@@ -377,6 +379,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -377,6 +379,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
if hasattr(self.language_model, "sampler"):
self.sampler = self.language_model.sampler
else:
self.sampler = Sampler()
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale # N, W, H, C --> N, W, H * scale, C // scale
...@@ -518,21 +525,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -518,21 +525,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) weights_group = group_weights_with_prefix(weights)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_model") self.vision_model.load_weights(weights_group["vision_model"])
self.vision_model.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters()) mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["mlp1"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal ...@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) weights_group = group_weights_with_prefix(weights)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -13,7 +12,6 @@ from typing_extensions import NotRequired ...@@ -13,7 +12,6 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -30,15 +28,8 @@ from .llava import LlavaMultiModalProjector ...@@ -30,15 +28,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
...@@ -87,17 +78,19 @@ def _get_llava_next_num_unpadded_features( ...@@ -87,17 +78,19 @@ def _get_llava_next_num_unpadded_features(
current_height = npatches * num_patch_height current_height = npatches * num_patch_height
current_width = npatches * num_patch_width current_width = npatches * num_patch_width
aspect_ratio = original_width / original_height original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height current_aspect_ratio = current_width / current_height
if aspect_ratio > current_aspect_ratio: if original_aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2 padding = (current_height - new_height) // 2
current_height -= padding * 2 current_height -= 2 * padding
else: else:
new_width = (original_width * current_height) // original_height scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2 padding = (current_width - new_width) // 2
current_width -= padding * 2 current_width -= 2 * padding
unpadded_features = current_height * current_width unpadded_features = current_height * current_width
newline_features = current_height newline_features = current_height
...@@ -635,25 +628,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -635,25 +628,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( weights_group = group_weights_with_prefix(weights)
weights, 4)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load newline # load newline
newline_weights = filter_weights(newline_weights, "image_newline") for name, loaded_weight in weights_group["image_newline"]:
for name, loaded_weight in newline_weights:
assert name == "" assert name == ""
param = self.image_newline param = self.image_newline
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
...@@ -661,5 +650,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -661,5 +650,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import itertools
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -12,7 +11,6 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, ...@@ -12,7 +11,6 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
...@@ -30,11 +28,9 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip ...@@ -30,11 +28,9 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__)
# For profile run # For profile run
_MAX_FRAMES_PER_VIDEO = 32 _MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1 _MAX_NUM_VIDEOS = 1
...@@ -449,23 +445,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -449,23 +445,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators # prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( weights_group = group_weights_with_prefix(weights)
weights, 4)
# load vision encoder # load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower") self.vision_tower.load_weights(weights_group["vision_tower"])
self.vision_tower.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name] param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(weights_group["language_model"])
self.language_model.load_weights(llm_weights)
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
SiglipVisionConfig)
from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
logger = init_logger(__name__)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
_MAX_NUM_VIDEOS = 1
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
"""
class LlavaOnevisionImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
image_sizes: NotRequired[torch.Tensor]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
class LlavaOnevisionImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
LlavaOnevisionImageEmbeddingInputs]
LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs]
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches,
scale_height,
scale_width):
current_height = patches * scale_height
current_width = patches * scale_width
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(height * (current_width / width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(width * (current_height / height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
return (unpadded_features, newline_features)
def get_llava_onevision_image_feature_size(
hf_config: LlavaOnevisionConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_onevision_image_unppaded_feature_size(
input_height, input_width, num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
def get_max_llava_onevision_image_tokens(ctx: InputContext):
return get_llava_onevision_image_feature_size(
ctx.get_hf_config(LlavaOnevisionConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_llava_onevision_video_frame_feature_size(
hf_config: LlavaOnevisionConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
hf_config, "spatial_pool_stride") else 2
height = width = image_size // patch_size
return math.ceil(height / spatial_pool_stride) * math.ceil(
width / spatial_pool_stride)
def get_llava_onevision_video_tokens(ctx: InputContext,
num_frames: int) -> int:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
# TODO: support configuring (not supported by HF right now)
num_token_image_newline = 1
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
return video_feature_size
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos > _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_image(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_onevision_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_onevision_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_video(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava_onevision(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data):
return llm_inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, llm_inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, llm_inputs)
msg = "Unsupported multi data type"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size,
config.text_config.hidden_size,
bias=True)
self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaOnevisionConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaOnevisionImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
return LlavaOnevisionImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input(
self,
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
"""
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values = kwargs.pop("pixel_values_videos", None)
if pixel_values is None:
return None
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
if "pixel_values" in kwargs:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if "pixel_values_videos" in kwargs:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self,
image_size: torch.Tensor,
patch_embeddings: torch.Tensor,
*,
image_newline=None,
vision_aspect_ratio="anyres_max_9",
strategy: str) -> torch.Tensor:
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the "
"image size.")
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# Move to CPU to avoid floating-point errors
orig_height, orig_width = image_size.tolist()
# image_aspect_ratio == "anyres"
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
(orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
num_patches = num_patch_height * num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches] \
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
(orig_height, orig_width))
max_num_patches = int(
vision_aspect_ratio.removeprefix("anyres_max_"))
channels, curr_height, curr_width = other_patch_embeds.shape
ratio = math.sqrt(curr_height * curr_width /
(max_num_patches * height**2))
if ratio > 1.1:
other_patch_embeds = other_patch_embeds[None]
other_patch_embeds = nn.functional.interpolate(
other_patch_embeds, [
int(curr_height // ratio),
int(curr_width // ratio)
],
mode="bilinear")[0]
if image_newline is not None:
other_patch_embeds = torch.cat(
(
other_patch_embeds,
image_newline[:, None, None] \
.expand(*other_patch_embeds.shape[:-1], 1) \
.to(other_patch_embeds.device),
),
dim=-1)
other_patch_embeds = other_patch_embeds \
.flatten(1, 2).transpose(0, 1)
else:
other_patch_embeds = other_patch_embeds \
.permute(0, 2, 1, 3, 4).contiguous() \
.flatten(0, 3)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(base_patch_embeds,
self.image_newline[None] \
.to(base_patch_embeds.device)
), dim=0)
else:
merged_patch_embeddings = base_patch_embeds
return merged_patch_embeddings
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self,
inputs: LlavaOnevisionImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self,
image_input: LlavaOnevisionImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
return [
self._merge_image_patch_embeddings(
image_sizes[i],
patch_features_batch,
image_newline=self.image_newline,
strategy="spatial_unpad")
for i, patch_features_batch in enumerate(patch_embeddings)
]
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
b, num_videos, frames, c, h, w = pixel_values.shape
assert (num_videos == _MAX_NUM_VIDEOS)
pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
strategy=self.config.vision_feature_select_strategy,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(
b, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to(
video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["data"]
# TODO: support multiple videos per input
if isinstance(video_pixels, torch.Tensor):
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, video_pixels)
return stacked_embeddings
else:
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def apply_pooling(self, image_features, stride=2):
vision_config = self.config.vision_config
height = width = vision_config.image_size // vision_config.patch_size
batch_frames, _, dim = image_features.shape
image_features = image_features.view(batch_frames, height, width, -1)
image_features = image_features.permute(0, 3, 1, 2)
# TODO support other pooling types config
height, width = image_features.shape[2:]
scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
image_feature = nn.functional.interpolate(image_features,
size=scaled_shape,
mode='bilinear')
image_feature = image_feature.permute(0, 2, 3, 1)
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for LlaVA-Onevision.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
# merge video embeddings into input embeddings
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -270,38 +270,47 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -270,38 +270,47 @@ class MiniCPMDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) self.rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) self.rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", self.max_position_embeddings = getattr(config,
8192) "max_position_embeddings", 8192)
self._init_attn_block()
self._init_ffn_block()
def _init_attn_block(self):
self.input_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.self_attn = MiniCPMAttention( self.self_attn = MiniCPMAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=self.config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=self.config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=self.rope_theta,
rope_scaling=rope_scaling, rope_scaling=self.rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
cache_config=cache_config, cache_config=self.cache_config,
quant_config=quant_config, quant_config=self.quant_config,
) )
def _init_ffn_block(self):
self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0: if self.num_experts == 0:
self.mlp = MiniCPMMLP( self.mlp = MiniCPMMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=self.config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=self.config.hidden_act,
quant_config=quant_config, quant_config=self.quant_config,
) )
else: else:
self.mlp = MiniCPMMoE(num_experts=config.num_experts, self.mlp = MiniCPMMoE(
top_k=config.num_experts_per_tok, num_experts=self.config.num_experts,
hidden_size=config.hidden_size, top_k=self.config.num_experts_per_tok,
intermediate_size=config.intermediate_size) hidden_size=self.config.hidden_size,
self.input_layernorm = RMSNorm(config.hidden_size, intermediate_size=self.config.intermediate_size)
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward( def forward(
self, self,
...@@ -344,6 +353,8 @@ class MiniCPMModel(nn.Module): ...@@ -344,6 +353,8 @@ class MiniCPMModel(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -354,11 +365,15 @@ class MiniCPMModel(nn.Module): ...@@ -354,11 +365,15 @@ class MiniCPMModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self._init_layers()
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def _init_layers(self):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, cache_config, quant_config) MiniCPMDecoderLayer(self.config, self.cache_config,
for _ in range(config.num_hidden_layers) self.quant_config)
for _ in range(self.config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
...@@ -431,13 +446,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -431,13 +446,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.cache_config = cache_config
self.quant_config = quant_config
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config self._init_model()
self.model = MiniCPMModel(config,
cache_config,
quant_config,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size unpadded_vocab_size += lora_config.lora_extra_vocab_size
...@@ -458,6 +471,12 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -458,6 +471,12 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
def _init_model(self):
self.model = MiniCPMModel(config=self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
lora_config=self.lora_config)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2024 The ModelBest team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)
class MiniCPM3Attention(nn.Module):
def __init__(
self,
config,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
self.kv_lora_rank +
self.qk_rope_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config)
# O projection.
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config)
self.rotary_emb = get_rope(
self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_local_heads,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
q, _ = self.q_a_proj(hidden_states)
q = self.q_a_layernorm(q)
q, _ = self.q_b_proj(q)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv, _ = self.kv_b_proj(kv_a)
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
q_pe, k_pe = self.rotary_emb(
positions,
q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
k_pe.reshape(-1, self.qk_rope_head_dim))
q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
k = k.view(-1, self.num_local_heads * self.qk_head_dim)
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
def _init_attn_block(self):
self.input_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.self_attn = MiniCPM3Attention(
config=self.config,
hidden_size=self.hidden_size,
num_heads=self.config.num_attention_heads,
qk_nope_head_dim=self.config.qk_nope_head_dim,
qk_rope_head_dim=self.config.qk_rope_head_dim,
v_head_dim=self.config.v_head_dim,
q_lora_rank=self.config.q_lora_rank,
kv_lora_rank=self.config.kv_lora_rank,
rope_theta=self.rope_theta,
rope_scaling=self.rope_scaling,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
)
class MiniCPM3Model(MiniCPMModel):
def _init_layers(self):
self.layers = nn.ModuleList([
MiniCPM3DecoderLayer(self.config, self.cache_config,
self.quant_config)
for _ in range(self.config.num_hidden_layers)
])
class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
def _init_model(self):
self.model = MiniCPM3Model(config=self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
lora_config=self.lora_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment