"csrc/quantization/w8a8/fp8/amd/quant_utils.cuh" did not exist on "08a1a1121d83a8b57a88cdec91e8ee15abb517f1"
Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
......@@ -23,7 +23,6 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from array import array
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict)
......@@ -34,11 +33,11 @@ from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -54,21 +53,30 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
"llm.model": "llm",
}
class MiniCPMVImageInput(TypedDict):
    """One image plus the special token ids used to locate image bounds.

    This is the per-image payload handed to the input mapper.
    """
    # The raw image to be preprocessed.
    image: Image.Image

    # Image-bound token ids, each stored as a 0-dim scalar tensor.
    im_start_id: torch.Tensor
    im_end_id: torch.Tensor

    # Slice-bound token ids; only present when the tokenizer defines them.
    slice_start_id: NotRequired[torch.Tensor]
    slice_end_id: NotRequired[torch.Tensor]
class MiniCPMVImagePixelInputs(TypedDict):
pixel_values: List[torch.Tensor]
"""
......@@ -93,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""
MiniCPMVImageInputs = MiniCPMVImagePixelInputs
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
......@@ -239,6 +245,25 @@ class Resampler2_5(BaseResampler):
return x
def _build_image_input(ctx: InputContext,
                       image: Image.Image) -> MiniCPMVImageInput:
    """Bundle an image with the tokenizer's image-bound token ids.

    The tokenizer is looked up through ``cached_get_tokenizer`` using the
    model config carried by ``ctx``. The slice start/end ids are only added
    when the tokenizer exposes a ``slice_start_id`` attribute.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    fields = {
        "image": image,
        "im_start_id": torch.tensor(tokenizer.im_start_id),
        "im_end_id": torch.tensor(tokenizer.im_end_id),
    }
    if hasattr(tokenizer, "slice_start_id"):
        fields["slice_start_id"] = torch.tensor(tokenizer.slice_start_id)
        fields["slice_end_id"] = torch.tensor(tokenizer.slice_end_id)

    return MiniCPMVImageInput(**fields)
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
version_float = getattr(config, "version", None)
......@@ -259,14 +284,16 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
return SequenceData(token_ids)
return SequenceData.from_token_counts((0, seq_len))
def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
num_images: int):
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
image = _build_image_input(ctx,
image=Image.new("RGB", (width, height),
color=0))
return {"image": [image] if num_images == 1 else [image] * num_images}
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
......@@ -275,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(hf_config, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)
return seq_data, mm_data
......@@ -286,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_processor = cached_get_image_processor(model_config.tokenizer)
def get_placeholder(image_size: Tuple[int, int], num_image: int):
......@@ -323,6 +351,10 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt = "".join(new_prompt_chunks)
new_token_ids = tokenizer.encode(new_prompt)
multi_modal_data["image"] = [
_build_image_input(ctx, image) for image in images
]
llm_inputs = LLMInputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
......@@ -331,6 +363,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
def input_mapper_for_minicpmv(ctx: InputContext, data: object):
    """Turn a list of ``MiniCPMVImageInput`` items into ``MultiModalInputs``.

    The images are run through the HuggingFace image processor as one batch,
    and the image-bound token ids of the first item are attached to the
    result (the ids are taken from ``data[0]`` only).

    Raises:
        RuntimeError: if no image processor could be obtained.
        ValueError: if ``data`` is not a list.
    """
    model_config = ctx.model_config
    image_processor = cached_get_image_processor(
        model_config.model, trust_remote_code=model_config.trust_remote_code)
    if image_processor is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")

    if not isinstance(data, list):
        # Exceptions do not apply %-style arguments the way logging calls
        # do, so the message has to be formatted before it is raised.
        raise ValueError(
            f"Image input must be list of MiniCPMVImageInput, got ({data})")

    batch_data = image_processor \
        .preprocess([img["image"] for img in data], return_tensors="pt") \
        .data

    if data:
        first = data[0]
        batch_data["im_start_id"] = first["im_start_id"]
        batch_data["im_end_id"] = first["im_end_id"]
        if "slice_start_id" in first:
            batch_data["slice_start_id"] = first["slice_start_id"]
            batch_data["slice_end_id"] = first["slice_end_id"]

    return MultiModalInputs(batch_data)
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
......@@ -371,7 +429,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
image_inputs: Optional[MiniCPMVImagePixelInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
......@@ -399,14 +457,20 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
return vlm_embedding, vision_hidden_states
def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
tokenizer = cached_get_tokenizer(self.config._name_or_path,
trust_remote_code=True)
start_cond = input_ids == tokenizer.im_start_id
end_cond = input_ids == tokenizer.im_end_id
if hasattr(tokenizer, "slice_start_id"):
start_cond |= (input_ids == tokenizer.slice_start_id)
end_cond |= (input_ids == tokenizer.slice_end_id)
def _get_image_bounds(
self,
input_ids: torch.Tensor,
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
if slice_start_id is not None:
start_cond |= (input_ids == slice_start_id[0])
end_cond |= (input_ids == slice_end_id[0])
image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1
......@@ -425,7 +489,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
) -> Optional[MiniCPMVImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", [])
......@@ -462,8 +526,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if len(pixel_values_flat) == 0:
return None
return MiniCPMVImageInputs(
image_bounds=self._get_image_bounds(input_ids),
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
if im_start_id is None:
return None
return MiniCPMVImagePixelInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
)
......@@ -570,8 +643,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool:
......@@ -660,8 +733,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
return self.get_vision_embedding(pixel_values)
......@@ -719,8 +792,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
......@@ -813,8 +886,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
).last_hidden_state
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
......@@ -857,7 +930,7 @@ _SUPPORT_VERSION = {
}
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
......@@ -884,7 +957,7 @@ class MiniCPMV(MiniCPMVBaseModel):
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version, None)
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
......
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from torch import nn
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP
logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = "<|image|>"
class MllamaImagePixelInputs(TypedDict):
    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]

    # Shape:
    # (batch_size, max_num_image, max_num_chunk, num_channel, height, width)
    data: torch.Tensor

    # Shape: `(batch_size, max_num_image)`
    aspect_ratio_ids: torch.Tensor

    # Shape: `(batch_size, max_num_image, max_num_tiles)`
    aspect_ratio_mask: torch.Tensor
# TODO: support LlamaImageEmbeddingInputs
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
    """Rewrite ``llm_inputs`` in place for the Mllama encoder/decoder split.

    The encoder prompt is copied to the decoder prompt when no decoder
    prompt is set. For text-only requests the encoder side is emptied. With
    images, the encoder prompt becomes a run of image placeholder tokens
    whose length is derived from the number of tiles of every image.
    """
    # The user-supplied encoder prompt doubles as the decoder prompt.
    if llm_inputs.get("prompt") is None:
        llm_inputs["prompt"] = llm_inputs["encoder_prompt"]
        llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"]

    # Multi-modal data is only accepted on the encoder side.
    assert "decoder_multi_modal_data" not in llm_inputs, \
        "multi-modal data should be put in encoder message of mllama"
    mm_data = llm_inputs.get("encoder_multi_modal_data")

    has_image = (mm_data is not None and "image" in mm_data
                 and mm_data["image"] is not None)
    if not has_image:
        # Text-only request: nothing for the encoder to consume.
        llm_inputs["encoder_prompt"] = ""
        llm_inputs["encoder_prompt_token_ids"] = []
        llm_inputs["encoder_multi_modal_data"] = {}
        return llm_inputs

    # Normalize a single image into a one-element list.
    if isinstance(mm_data["image"], Image.Image):
        mm_data["image"] = [mm_data["image"]]

    # Count the tiles needed across all images.
    vision_config = ctx.model_config.hf_config.vision_config
    num_tiles = 0
    for img in mm_data["image"]:
        img_width, img_height = img.size
        tile_size = vision_config.image_size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height=img_height,
            image_width=img_width,
            max_image_tiles=vision_config.max_num_tiles,
            tile_size=tile_size,
        )
        tiles_down = canvas_height // tile_size
        tiles_across = canvas_width // tile_size
        num_tiles += tiles_down * tiles_across

    # One placeholder token per patch, plus one extra token per tile.
    assert vision_config.image_size % 14 == 0, \
        "chunk size should be multiple of 14"
    token_per_chunk = (vision_config.image_size // 14)**2 + 1
    num_tokens = num_tiles * token_per_chunk

    llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
    llm_inputs["encoder_prompt_token_ids"] = \
        [MLLAMA_IMAGE_TOKEN_ID] * num_tokens
    return llm_inputs
def get_max_mllama_image_tokens(ctx: InputContext) -> int:
    """Return the token count of one image that uses the maximum tile count.

    Each tile contributes ``(image_size // 14) ** 2 + 1`` tokens.
    """
    vision_config = ctx.model_config.hf_config.vision_config
    patches_per_side = vision_config.image_size // 14
    token_per_chunk = patches_per_side**2 + 1
    return vision_config.max_num_tiles * token_per_chunk
def dummy_decoder_seq_data(seq_len: int, num_images: int):
    """Build dummy decoder tokens: image tokens first, zeros after.

    The result holds ``num_images`` image token ids followed by
    ``seq_len - num_images`` zero ids, wrapped in ``SequenceData``.
    """
    assert seq_len >= num_images, \
        "seq_len should be greater than or equal to num_images"
    image_part = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [MLLAMA_IMAGE_TOKEN_ID]) * num_images
    filler_part = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                        [0]) * (seq_len - num_images)
    return SequenceData(image_part + filler_part)
def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
    """Build dummy encoder tokens made only of image placeholder ids.

    The length is the per-image maximum token count times ``num_images``.
    """
    total_tokens = get_max_mllama_image_tokens(ctx) * num_images
    single_token = array(VLLM_TOKEN_ID_ARRAY_TYPE, [MLLAMA_IMAGE_TOKEN_ID])
    return SequenceData(single_token * total_tokens)
def dummy_image(num_images: int, ):
    """Return dummy multi-modal data holding black 1024x1024 RGB images.

    A single image is returned bare; any other count yields a list that
    repeats the same image object ``num_images`` times.
    """
    side = 1024
    black_image = Image.new("RGB", (side, side), color=0)
    if num_images == 1:
        return {"image": black_image}
    return {"image": [black_image] * num_images}
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Return dummy decoder sequence data and no multi-modal data."""
    seq_data = dummy_decoder_seq_data(seq_len, mm_counts["image"])
    return seq_data, None
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Return dummy encoder sequence data together with dummy images.

    ``seq_len`` is not used; the encoder length depends only on the image
    count.
    """
    image_count = mm_counts["image"]
    seq_data = dummy_encoder_seq_data(ctx, image_count)
    mm_data = dummy_image(image_count)
    return seq_data, mm_data
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
attention_mask = attention_mask.reshape(batch_size,
max_num_tiles * target_length, 1)
attention_mask = attention_mask @ attention_mask.transpose(
-1, -2) * torch.finfo(dtype).min
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Patch-embedding layer built from an unfold plus a parallel linear.

    The input is unfolded into flattened patches, which are then projected
    by a ``ColumnParallelLinear``.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel.
        stride: Stride for convolution.
        bias (default False): Use bias in the linear projection.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # A scalar kernel size means a square kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        patch_area = kernel_size[0] * kernel_size[1]

        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(
            in_channels * patch_area,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unfold puts the flattened patch on dim 1; move it to the last dim
        # so the linear layer projects each patch.
        patches = self._unfold(x).permute(0, 2, 1)
        projected, _ = self._linear(patches)
        return projected
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a learned per-tile embedding selected by aspect-ratio id.

    When ``is_gated`` is set, the embedding is scaled by ``tanh`` of a
    learned scalar gate (initialized to zero) before being added.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # One row per aspect-ratio id; each row packs one vector per tile.
        num_ids = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_ids, row_width)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        tile_embeds = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)
        if self.is_gated:
            tile_embeds = tile_embeds * self.gate.tanh()
        return hidden_state + tile_embeds
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds a gated mix of patch-position and tile-position embeddings.

    A single learned gate ``g`` blends the two: the shared position
    embedding is weighted by ``1 - tanh(g)`` and the tile position
    embedding, selected by aspect-ratio id, by ``tanh(g)``.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one extra position.
        self.num_patches = (config.image_size // config.patch_size)**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # Position embedding shared by all tiles.
        initial_positions = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * initial_positions)

        # Tile position embedding: one flattened table per aspect-ratio id.
        self.tile_embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1,
            self.max_num_tiles * self.num_patches * self.hidden_size)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        # Shared position embedding, broadcast over batch and tiles.
        shared_positions = (1 - self.gate.tanh()) * self.embedding
        hidden_state = hidden_state + shared_positions.view(
            1, 1, self.num_patches, self.hidden_size)

        # Precomputed tile position embedding for each aspect-ratio id.
        num_rows = hidden_state.shape[0]
        tile_positions = self.tile_embedding(aspect_ratio_ids).reshape(
            num_rows, self.max_num_tiles, self.num_patches, self.hidden_size)
        return hidden_state + self.gate.tanh() * tile_positions
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Vision self-attention built on ``F.scaled_dot_product_attention``.

    Heads are divided by the tensor-parallel world size; the fused QKV
    projection and the output projection are the parallel linear layers.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()

        model_parallel_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // model_parallel_size
        # Q, K and V all use the same number of local heads.
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
        )

    def _to_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Split the last dim into local heads and move heads to dim 1."""
        return tensor.view(tensor.shape[0], tensor.shape[1],
                           self.num_local_heads,
                           self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        fused_qkv, _ = self.qkv_proj(hidden_state)
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)
        query = self._to_heads(query)
        key = self._to_heads(key)
        value = self._to_heads(value)

        # TODO: remove padding in image encoder
        attended = F.scaled_dot_product_attention(query,
                                                  key,
                                                  value,
                                                  attn_mask=attention_mask,
                                                  dropout_p=0.0)

        # Move heads back next to head_dim and merge them.
        attended = attended.transpose(1, 2).contiguous()
        attended = attended.reshape(attended.shape[0], attended.shape[1], -1)
        projected, _ = self.o_proj(attended)
        return projected
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm vision transformer layer with optional tanh gating.

    When ``is_gated`` is set, the attention and MLP branches are each
    scaled by ``tanh`` of a learned scalar before the residual add.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = False):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(config)
        self.mlp = CLIPMLP(config)

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # Gates only exist on gated layers; both start at pi / 4.
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self-attention branch.
        attn_out = self.self_attn(self.input_layernorm(hidden_state),
                                  attention_mask=attention_mask)
        if self.is_gated:
            attn_scale = self.gate_attn.tanh()
        else:
            attn_scale = 1
        hidden_state = hidden_state + attn_scale * attn_out

        # Feed-forward branch.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_state))
        if self.is_gated:
            mlp_scale = self.gate_ffn.tanh()
        else:
            mlp_scale = 1
        hidden_state = hidden_state + mlp_scale * mlp_out

        return hidden_state
class MllamaVisionEncoder(nn.Module):
    """Stack of vision encoder layers that can expose intermediate states.

    ``output_hidden_states`` lists layer indices. For each listed index the
    input to that layer is collected; if the last layer's index is listed,
    the final output is collected as well.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 num_layers=32,
                 is_gated=False,
                 output_hidden_states=None):
        super().__init__()
        self.config = config
        layers = [
            MllamaVisionEncoderLayer(config, is_gated)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(layers)
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        collected = []

        for index, layer in enumerate(self.layers):
            # Record the state entering a requested layer.
            if index in self.output_hidden_states:
                collected.append(hidden_states)
            hidden_states = layer(
                hidden_states,
                attention_mask,
            )

        last_index = len(self.layers) - 1
        if last_index in self.output_hidden_states:
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class MllamaVisionModel(nn.Module):
    """Tiled vision encoder for Mllama.

    Images arrive as tiles. Each tile is patch-embedded, gets tile and
    position embeddings plus a class token, and is run through a local
    transformer and then a gated global transformer. The final features are
    concatenated with the intermediate states collected from the local
    transformer.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile plus one for the class token.
        self.num_patches = (self.image_size // self.patch_size)**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale *
                                            torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)

        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices)
        self.global_transformer = MllamaVisionEncoder(config,
                                                      config.num_global_layers,
                                                      is_gated=True)

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding as the first token of every row."""
        num_rows, _, width = hidden_state.shape
        cls_tokens = self.class_embedding.expand(num_rows, 1, width)
        return torch.cat([cls_tokens, hidden_state], dim=1)

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        batch_size, num_concurrent_media, num_tiles, num_channels, \
            height, width = pixel_values.shape
        # Batch and media dims are processed as one flattened dim.
        num_media = batch_size * num_concurrent_media
        compute_dtype = self.layernorm_pre.weight.dtype

        pixel_values = pixel_values.reshape(num_media * num_tiles,
                                            num_channels, height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # patch embedding, gathered across the tensor-parallel group
        hidden_state = self.patch_embedding(pixel_values.to(compute_dtype))
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)

        # class token, one per tile
        hidden_state = hidden_state.reshape(num_media * num_tiles,
                                            num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # position embeddings
        hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches,
                                            dim)
        hidden_state = self.gated_positional_embedding(hidden_state,
                                                       aspect_ratio_ids)

        hidden_state = self.layernorm_pre(hidden_state)

        # Pad the patch dim (dim -2) up to the next multiple of 8.
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # F.pad order: (left, right) for the last dim, then for dim -2.
        pad_spec = (0, 0, 0, num_padding_patches)
        hidden_state = F.pad(hidden_state, pad_spec, mode="constant", value=0)
        # A None stop keeps the full axis when nothing was padded.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None
        padded_patches = num_patches + num_padding_patches

        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=aspect_ratio_mask.reshape(num_media, -1),
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=compute_dtype,
        )

        # local encoder over all tiles of a media item at once
        hidden_state = hidden_state.view(num_media, -1, dim)
        hidden_state, intermediate_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        intermediate_states = torch.stack(intermediate_states, dim=-1)

        # global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(num_media, num_tiles,
                                            padded_patches, dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)
        hidden_state = hidden_state.reshape(num_media,
                                            num_tiles * padded_patches, dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask)[0]

        # drop the padding patches again
        hidden_state = hidden_state.reshape(num_media, num_tiles,
                                            padded_patches, dim)
        hidden_state = hidden_state[:, :, :slice_index]
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)

        # append the intermediate layer outputs on the feature dim
        intermediate_states = intermediate_states.reshape(
            num_media, num_tiles, padded_patches, -1)
        intermediate_states = intermediate_states[:, :, :slice_index]
        intermediate_states = intermediate_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1)

        return torch.cat([hidden_state, intermediate_states], dim=-1)
class MllamaTextRMSNorm(nn.Module):
    """RMS normalization computed in float32 (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Do the statistics in float32, then cast back before scaling.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square +
                                           self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
    """Cross-attention from text hidden states to vision states.

    Queries are projected from the text hidden states; keys and values are
    projected from ``cross_attention_states`` when they are given. Both
    queries and keys are RMS-normalized per head before attention.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()

        # Head counts, globally and per tensor-parallel rank.
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = \
            self.num_key_value_heads // self.model_parallel_size

        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        # TODO: change to Q/KV separate linear after #7448 is merged
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )

        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        split_sizes = [
            self.q_local_size, self.kv_local_size, self.kv_local_size
        ]

        # Only the query slice of the text-side projection is used.
        decoder_qkv, _ = self.qkv_proj(hidden_states)
        query, _, _ = decoder_qkv.split(split_sizes, dim=-1)

        # Keys and values exist only when vision states are supplied.
        key = None
        value = None
        if cross_attention_states is not None:
            encoder_qkv, _ = self.qkv_proj(cross_attention_states)
            _, key, value = encoder_qkv.split(split_sizes, dim=-1)
            key = key.view(-1, self.num_local_key_value_heads, self.head_dim)
            value = value.view(-1, self.num_local_key_value_heads,
                               self.head_dim)
            key = self.k_norm(key)

        query = query.view(-1, self.num_local_heads, self.head_dim)
        query = self.q_norm(query)

        attended = self.attn(query,
                             key,
                             value,
                             kv_cache,
                             attn_metadata,
                             attn_type=AttentionType.ENCODER_DECODER)
        projected, _ = self.o_proj(attended)
        return projected
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
    and feedforward.

    Both the attention branch and the MLP branch are multiplied by
    ``full_text_row_masked_out_mask`` and by ``tanh`` of a learned scalar
    gate before being added to the residual stream.
    """

    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
                 quant_config: Optional[QuantizationConfig]) \
        -> None:
        """Create the attention, MLP, norm and gate sub-modules.

        Args:
            config: Text model configuration.
            layer_idx: Index of this layer within the decoder stack.
            quant_config: Quantization settings forwarded to the
                attention and MLP sub-modules.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        # Scalar gate, created as zeros; tanh(gate) scales the attention
        # branch in forward().
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        # Scalar gate for the MLP branch, also created as zeros.
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply gated cross-attention followed by a gated MLP.

        Args:
            hidden_states: Input to the block.
            cross_attention_states: Encoder-side states given to the
                cross-attention module.
            cross_attention_mask: Passed to the cross-attention module as
                ``attention_mask``.
            full_text_row_masked_out_mask: Multiplied into the output of
                both branches before the residual add.
            kv_cache: Cache for this layer, passed through to the
                cross-attention module.
            attn_metadata: Attention metadata, passed through.

        Returns:
            The updated hidden states.
        """
        # --- gated cross-attention branch ---
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh(
        ) * hidden_states

        # --- gated feed-forward branch ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh(
        ) * hidden_states
        return hidden_states
class MllamaTextModel(nn.Module):
    """Decoder stack mixing self-attention and cross-attention layers.

    Layers whose index appears in ``config.cross_attention_layers`` are
    ``MllamaCrossAttentionDecoderLayer``; all others are
    ``LlamaDecoderLayer``.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        """Build the embedding table, the layer stack and the final norm.

        Args:
            config: Text model configuration.
            cache_config: Cache settings, given to the self-attention
                layers only.
            quant_config: Quantization settings given to every layer.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # NOTE(review): the embedding table has 8 more rows than
        # ``vocab_size`` — presumably for extra special tokens; confirm.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config, layer_idx, quant_config=quant_config))
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(config,
                                      cache_config=cache_config,
                                      quant_config=quant_config))

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed the tokens and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Positions given to the self-attention layers.
            cross_attention_states: Encoder-side states given to the
                cross-attention layers.
            cross_attention_mask: Mask given to the cross-attention layers.
            full_text_row_masked_out_mask: Mask given to the
                cross-attention layers.
            kv_caches: One cache entry per layer, indexed by layer
                position.
            attn_metadata: Attention metadata given to every layer.
            skip_cross_attention: When True, cross-attention layers are
                not executed at all and the hidden states pass through
                them unchanged.

        Returns:
            Hidden states after the final norm.

        Raises:
            ValueError: If a layer is neither of the two known layer types.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        full_text_row_masked_out_mask=
                        full_text_row_masked_out_mask,
                        kv_cache=kv_caches[idx],
                        attn_metadata=attn_metadata,
                    )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                # The self-attention layer is always called with
                # residual=None and hands back (hidden_states, residual);
                # the two are summed here so the next layer again receives
                # a single tensor.
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    kv_cache=kv_caches[idx],
                    attn_metadata=attn_metadata,
                    residual=None,
                )
                hidden_states = hidden_states + residual
            else:
                raise ValueError(
                    f"Unknown decoder layer type {type(decoder_layer)}")
        hidden_states = self.norm(hidden_states)
        return hidden_states
class MllamaForCausalLM(nn.Module):
    """Text decoder (``MllamaTextModel``) bundled with its LM head.

    ``forward`` only produces hidden states; the LM head is held here as
    an attribute and is not applied inside this class.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
    ]

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        """Instantiate the decoder stack and the output embedding.

        Args:
            config: Text model configuration.
            cache_config: Cache settings handed to the decoder stack.
            quant_config: Quantization settings for the stack and LM head.
        """
        super().__init__()
        vocab_size = config.vocab_size
        self.vocab_size = vocab_size
        self.model = MllamaTextModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            vocab_size,
            config.hidden_size,
            org_num_embeddings=vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the wrapped decoder and return its hidden states.

        Every argument is passed through by keyword, unchanged.
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Mllama vision-language model: vision encoder, projector and decoder.

    Pixel inputs are encoded by ``vision_model``, projected to the text
    hidden size by ``multi_modal_projector`` and handed to the language
    model as cross-attention states.
    """

    def __init__(self,
                 config: config_mllama.MllamaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the vision model, language model, projector and sampler.

        Args:
            config: Combined configuration with ``vision_config`` and
                ``text_config`` sections.
            multimodal_config: Accepted for interface compatibility; not
                read in this constructor.
            cache_config: Cache settings forwarded to the language model.
            quant_config: Quantization settings forwarded to the language
                model.
        """
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(config.vision_config)
        self.language_model = MllamaForCausalLM(
            config.text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        # The first argument is the vocabulary size. The previous code
        # passed ``config.output_hidden_states`` (a flag) in that slot,
        # which is not a vocabulary size.
        self.logits_processor = LogitsProcessor(config.text_config.vocab_size,
                                                config.text_config.vocab_size)
        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states using the language-model head."""
        logits = self.logits_processor(self.language_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract image inputs from ``kwargs`` and pad them into tensors.

        Reads ``pixel_values``, ``image_embeds``, ``aspect_ratio_ids`` and
        ``aspect_ratio_mask`` from ``kwargs``.

        Returns:
            None when neither pixel values nor image embeds are given,
            otherwise a ``MllamaImagePixelInputs`` whose tensors are padded
            to the largest image count and tile count in the batch.

        Raises:
            ValueError: If both pixel values and image embeds are given,
                if pixel values come without aspect-ratio ids or mask, or
                if no images are present.
            NotImplementedError: If only image embeds are given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalInputs.batch, so pixel_values here can be:
        #   - List[List[torch.Tensor]]:
        #       with shape (num_tiles, 3, image_res, image_res)
        #   - List[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
                                         List[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
                                          List[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            # Explicit checks instead of ``assert`` so the validation is
            # still performed when Python runs with assertions disabled.
            if aspect_ratio_ids is None:
                raise ValueError(
                    "aspect_ratio_ids must be provided with pixel_values.")
            if aspect_ratio_mask is None:
                raise ValueError(
                    "aspect_ratio_mask must be provided with pixel_values.")
            max_num_images = max(len(x[0]) for x in pixel_values)
            if max_num_images == 0:
                raise ValueError("No images provided.")
            max_num_tiles = max(
                max(len(x) for x in y[0]) for y in pixel_values)
            device = self.multi_modal_projector.weight.device
            bsz = len(pixel_values)
            # Padded outputs: images and mask are zero-filled, aspect
            # ratio ids are one-filled.
            out_images = torch.zeros(
                bsz,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
                device=device,
            )
            out_ar_ids = torch.ones(bsz,
                                    max_num_images,
                                    dtype=torch.int64,
                                    device=device)
            out_ar_mask = torch.zeros(bsz,
                                      max_num_images,
                                      max_num_tiles,
                                      dtype=torch.int64,
                                      device=device)
            for b, batch_item in enumerate(pixel_values):
                for i, img in enumerate(batch_item[0]):
                    # Only the first ``img.shape[0]`` tile slots are
                    # written; the remaining slots keep their fill value.
                    out_images[b, i, :img.shape[0]] = img
                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=out_images,
                aspect_ratio_ids=out_ar_ids,
                aspect_ratio_mask=out_ar_mask,
            )

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata):
        """Flatten batched encoder states and build the text-row mask.

        Args:
            cross_attention_states: Encoder states, iterated along the
                first dimension with one entry per
                ``attn_metadata.encoder_seq_lens`` element.
            attn_metadata: Supplies ``encoder_seq_lens``,
                ``seq_lens_tensor`` and ``num_prefill_tokens``.

        Returns:
            A pair ``(flat_states, mask)``. ``flat_states`` concatenates
            the first ``encoder_seq_len`` rows of every batch entry.
            ``mask`` is a bool tensor of shape ``(num_prefill_tokens, 1)``
            that is False for the rows of sequences whose encoder length
            is 0 and True elsewhere.
        """
        cross_attention_states_flat = torch.zeros(
            sum(attn_metadata.encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(
                attn_metadata.encoder_seq_lens, cross_attention_states):
            end_pos = start_pos + seq_len
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat

        # The mask is filled on the default device and moved once at the
        # end.
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(
                attn_metadata.seq_lens_tensor.cpu(),
                attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            cross_attention_states.device)

        return cross_attention_states, full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Encode any images and run the language model.

        Args:
            input_ids: Token ids for the language model.
            positions: Token positions for the language model.
            kv_caches: Per-layer caches for the language model.
            attn_metadata: Attention metadata; must not contain prefill
                and decode tokens at the same time.
            **kwargs: Multimodal inputs, parsed by
                ``_parse_and_validate_image_input``.

        Returns:
            The output of the language model.

        Raises:
            ValueError: If the batch mixes prefill and decode tokens.
        """
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        if image_inputs is None:
            cross_attention_mask = None
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
            cross_attention_states = None
            # Cross-attention is skipped only when no sequence in the
            # batch has any encoder tokens.
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
        else:
            # NOTE: llama's reference implementation runs vision model on CPU
            pixel_values = image_inputs['data']
            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
            cross_attention_states = self.vision_model(pixel_values,
                                                       aspect_ratio_ids,
                                                       aspect_ratio_mask)
            cross_attention_states = self.multi_modal_projector(
                cross_attention_states)

            # Collapse every dimension between batch and feature into one.
            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
            cross_attention_states = cross_attention_states.view(
                bsz, -1, image_token_dim)

            cross_attention_states, full_text_row_masked_out_mask = \
                self.flat_encoder_result(cross_attention_states, attn_metadata)
            skip_cross_attention = False
            # TODO: support multi-image by this mask
            cross_attention_mask = None

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Checkpoint q/k/v and gate/up projections are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id, and
        ``patch_embedding.weight`` is flattened to 2-D and renamed to
        ``patch_embedding._linear.weight``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Raises:
            KeyError: If a name has no matching parameter, or a
                non-stacked name is seen twice (entries are popped once
                loaded).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly into its parameter.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
class OlmoeMoE(nn.Module):
    """Tensor-parallel mixture-of-experts layer for Olmoe.

    Every expert's weights are sharded across all ranks. A fused MoE
    kernel runs the forward pass and the per-rank outputs are reduced
    afterwards (``reduce_results=True``).
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """Create the router and the fused expert layer.

        ``params_dtype`` and ``prefix`` are accepted but not read here.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # The router is built with quant_config=None: the gate always
        # runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,
            renormalize=False,
            quant_config=quant_config,
            tp_size=tp_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts; output keeps the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape, so flatten
        # to (num_tokens, hidden_dim) first and restore the shape at the
        # end.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        expert_output = self.experts(hidden_states=tokens,
                                     router_logits=router_logits)
        return expert_output.view(input_shape)
class OlmoeAttention(nn.Module):
    """Self-attention block for Olmoe.

    Uses a fused QKV projection, RMSNorm applied to the query and key
    projections, rotary position embeddings and an output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build projections, norms, rotary embedding and attention.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of query heads; must be divisible by
                the tensor-parallel world size.
            num_kv_heads: Total number of key/value heads; must divide, or
                be divisible by, the tensor-parallel world size.
            rope_theta: Base passed to the rotary embedding.
            rope_scaling: Optional scaling config for the rotary embedding.
            max_position_embeddings: Maximum position for the rotary
                embedding.
            cache_config: Cache settings forwarded to the attention layer.
            quant_config: Quantization settings forwarded to the linear
                and attention layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        # NOTE(review): both norms are sized by the full ``hidden_size``,
        # while forward() applies them to slices of width ``q_size`` and
        # ``kv_size``. This looks like it assumes tp_size == 1 and
        # num_kv_heads == num_heads — confirm before using other setups.
        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor forwarded to the attention layer.
            attn_metadata: Attention metadata forwarded to the attention
                layer.

        Returns:
            The first element of the ``o_proj`` output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # split() returns views into qkv; they are made contiguous before
        # being normalized.
        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class OlmoeDecoderLayer(nn.Module):
    """One Olmoe decoder layer: self-attention followed by the MoE block.

    The residual stream is carried separately from the hidden states and
    returned alongside them.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MoE and norm sub-modules from ``config``.

        Args:
            config: Model configuration. ``rope_theta``, ``rope_scaling``
                and ``max_position_embeddings`` fall back to 10000, None
                and 4096 when absent.
            layer_idx: Accepted but not read in this constructor.
            cache_config: Cache settings forwarded to the attention block.
            quant_config: Quantization settings forwarded to the attention
                and MoE blocks.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = OlmoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )

        self.mlp = OlmoeMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            positions: Token positions for the attention block.
            hidden_states: Input hidden states.
            kv_cache: Cache tensor for this layer.
            attn_metadata: Attention metadata.
            residual: Residual carried from the previous layer, or None
                for the first layer.

        Returns:
            A ``(hidden_states, residual)`` pair to feed the next layer.
        """
        # Self Attention
        if residual is None:
            # First layer: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with two arguments the norm returns a
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class OlmoeModel(nn.Module):
    """Olmoe backbone: token embedding, decoder layers and a final norm."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the embedding, ``config.num_hidden_layers`` decoder
        layers and the closing RMSNorm."""
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            OlmoeDecoderLayer(config,
                              layer_idx,
                              cache_config,
                              quant_config=quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run the decoder stack.

        Each layer receives the cache at its own index in ``kv_caches``
        and the residual returned by the previous layer (None for the
        first). The final norm consumes the last hidden states and
        residual; only its first output is returned.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                attn_metadata,
                residual,
            )
        normed_states, _ = self.norm(hidden_states, residual)
        return normed_states
class OlmoeForCausalLM(nn.Module):
    """Olmoe causal language model: backbone, LM head, logits and sampling."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Cache settings forwarded to the backbone.
            quant_config: Quantization settings forwarded to the backbone
                and the LM head.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OlmoeModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the backbone's hidden states.

        ``intermediate_tensors`` is accepted but not read in this method.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is handled by the first matching rule:

        1. Names containing ``rotary_emb.inv_freq`` are skipped.
        2. Non-expert q/k/v and gate/up projections are loaded into the
           fused ``qkv_proj`` / ``gate_up_proj`` parameters by shard id.
        3. Expert weights are loaded through the expert mapping with
           their expert id and shard id.
        4. Everything else is loaded directly, after skipping unknown
           ``.bias`` entries and remapping ``kv_scale`` names.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping loaded this entry; try the experts.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Neither a stacked nor an expert weight.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import filter_weights, merge_multimodal_embeddings
from .utils import group_weights_with_prefix, merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -153,7 +152,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
config.text_config.vocab_size,
logit_scale)
self.sampler = Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
......@@ -286,21 +286,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
weights_group = group_weights_with_prefix(weights)
# load vision tower
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
......@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vocab_size = config.vocab_size
self.vocab_size = config.text_config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.embed_tokens = VocabParallelEmbedding(
config.text_config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([
PersimmonDecoderLayer(config,
cache_config=cache_config,
......@@ -257,14 +257,14 @@ class PersimmonForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.vocab_size = config.text_config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
config.hidden_size,
bias=False)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = Sampler()
def forward(
......
# coding=utf-8
# Adapted from llama.py
"""Inference-only Phi3 model code inherit from Llama.py"""
from vllm.model_executor.models.llama import LlamaForCausalLM
class Phi3ForCausalLM(LlamaForCausalLM):
    """Phi3 causal LM; reuses the Llama implementation unchanged.

    The only override is ``packed_modules_mapping``, in which each packed
    module maps to itself alone.
    """

    packed_modules_mapping = {
        "qkv_proj": ["qkv_proj"],
        "gate_up_proj": ["gate_up_proj"],
    }
......@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
transposed = False
if width < height:
width, height = height, width
......@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
*,
input_height: int,
input_width: int,
num_crops: int,
) -> int:
num_crops = hf_config.get("num_crops", 16)
if num_crops is None:
num_crops = hf_config.get("num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
......@@ -347,20 +349,26 @@ def get_phi3v_image_feature_size(
+ (new_height // 336 + 1) * 12
def get_max_phi3v_image_tokens(ctx: InputContext):
def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
return get_phi3v_image_feature_size(
ctx.get_hf_image_processor_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
num_crops=num_crops,
)
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
def dummy_data_for_phi3v(ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
num_crops: Optional[int] = None):
num_images = mm_counts["image"]
image_feature_size = get_max_phi3v_image_tokens(ctx)
image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
......@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
return image_placeholder_token_ids
def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
def input_processor_for_phi3v(ctx: InputContext,
llm_inputs: LLMInputs,
*,
num_crops: Optional[int] = None):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
......@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size = [
get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h)
input_height=h,
num_crops=num_crops)
]
image_data = [image_data]
elif is_list_of(image_data, Image.Image):
......@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size.append(
get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h))
input_height=h,
num_crops=num_crops))
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
......
......@@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=None,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=None,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -491,6 +491,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
"o_proj",
"embed_tokens",
"lm_head",
"w1",
"w2",
"w3",
"gate",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
......
from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union
......@@ -24,8 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal
from .utils import init_vllm_registered_model
......@@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_feature_size = (size**2) // (patch_size**2)
num_image_tokens = image_feature_size * num_images
seq_data = SequenceData.from_token_counts(
(image_token_id, num_image_tokens),
(0, seq_len - num_image_tokens),
)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)
seq_data = SequenceData(token_ids)
mm_data = {"image": num_images * [image]}
return seq_data, mm_data
......@@ -454,7 +450,7 @@ class Transformer(nn.Module):
return x
def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor:
positions = torch.cat([
torch.stack(
torch.meshgrid(
......
......@@ -7,7 +7,6 @@
import math
import re
from array import array
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, TypedDict, Union)
......@@ -48,8 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
......@@ -689,8 +687,9 @@ def input_processor_for_qwen(ctx: InputContext,
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_data = multi_modal_data["image"]
if isinstance(image_data, torch.Tensor):
num_dims = len(image_data.shape)
......@@ -750,8 +749,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
return MultiModalInputs()
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
add_special_tokens=False,
......@@ -832,15 +832,16 @@ def dummy_data_for_qwen(
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
if not hasattr(hf_config, "visual"):
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
seq_data = SequenceData.from_token_counts((0, seq_len))
mm_data = None
return seq_data, mm_data
# We have a visual component - use images to warm up
num_images = mm_counts["image"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt = ''.join(
......@@ -859,11 +860,13 @@ def dummy_data_for_qwen(
if len(toks) < seq_len:
toks += [0] * (seq_len - len(toks))
seq_data = SequenceData.from_seqs(toks)
# Build the input images; width/height doesn't actually matter here since
# the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images}
return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
return seq_data, mm_data
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
......
......@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
......@@ -247,11 +247,16 @@ class Qwen2Model(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen2DecoderLayer(config=config,
......@@ -260,7 +265,10 @@ class Qwen2Model(nn.Module):
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......
......@@ -22,7 +22,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from array import array
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
Union)
......@@ -46,7 +45,7 @@ from vllm.attention import AttentionMetadata
from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
......@@ -66,9 +65,12 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
logger = init_logger(__name__)
......@@ -207,7 +209,7 @@ class Qwen2VisionAttention(nn.Module):
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
device_available = current_platform.has_device_capability(80)
if device_available:
from transformers.utils import is_flash_attn_2_available
......@@ -280,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif is_cpu():
seq_length = q.size(1)
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
......@@ -681,15 +698,14 @@ def dummy_data_for_qwen2_vl(
"--limit-mm-per-prompt.")
hf_config = ctx.get_hf_config(Qwen2VLConfig)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.vision_start_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.image_token_id]) * max_llm_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.vision_end_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - max_llm_image_tokens - 2)
dummy_seqdata = SequenceData(token_ids)
dummy_seqdata = SequenceData.from_token_counts(
(hf_config.vision_start_token_id, 1),
(hf_config.image_token_id, max_llm_image_tokens),
(hf_config.vision_end_token_id, 1),
(0, seq_len - max_llm_image_tokens - 2),
)
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)
......@@ -859,15 +875,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
......@@ -982,7 +1004,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
if (image_input is None
and video_input is None) or not get_pp_group().is_first_rank:
inputs_embeds = None
else:
if getattr(self.config, "rope_scaling", {}).get("type",
......@@ -1018,6 +1041,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
......@@ -1058,6 +1082,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -1084,6 +1110,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
......
......@@ -2,9 +2,9 @@
within a vision language model."""
import math
from array import array
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from torch import nn
......@@ -24,7 +24,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from vllm.sequence import SequenceData
try:
from xformers import ops as xops
......@@ -67,11 +67,10 @@ def dummy_seq_data_for_siglip(
else:
image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size)
return SequenceData(token_ids)
return SequenceData.from_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
def dummy_image_for_siglip(
......@@ -91,6 +90,24 @@ def dummy_image_for_siglip(
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_siglip(
    hf_config: SiglipVisionConfig,
    num_frames: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Build dummy video multimodal data by repeating one dummy frame.

    Args:
        hf_config: Vision config forwarded to ``dummy_image_for_siglip``.
        num_frames: Number of times the single frame is repeated.
        image_width_override: Forwarded to ``dummy_image_for_siglip``.
        image_height_override: Forwarded to ``dummy_image_for_siglip``.

    Returns:
        A dict with the single key ``"video"`` holding the stacked frames.
    """
    single_image = dummy_image_for_siglip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override,
    )
    frame = np.array(single_image["image"])
    # Wrapping the frame in a list adds a leading axis; np.repeat then
    # duplicates the frame ``num_frames`` times along that axis.
    frames = np.repeat([frame], num_frames, axis=0)
    return {"video": frames}
def input_processor_for_siglip(
model_config: ModelConfig,
hf_config: SiglipVisionConfig,
......@@ -503,6 +520,7 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
......@@ -513,10 +531,6 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override=num_hidden_layers_override,
)
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
......@@ -542,12 +556,12 @@ class SiglipVisionModel(nn.Module):
for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
if (name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.utils import (PPMissingLayer,
is_pp_missing_parameter,
make_layers)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
class SolarMLP(nn.Module):
    """Gated feed-forward block of a Solar decoder layer.

    A merged gate/up projection feeds ``SiluAndMul``; the activated result is
    projected back to ``hidden_size`` by ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Create the two projections and the activation.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # A single merged layer with two equally sized outputs (gate and up).
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection."""
        # Both linear layers return a 2-tuple; only the first item is used.
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected
class SolarAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention over the KV cache, and output projection."""

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        """Derive per-rank head counts and build the sub-layers.

        The query heads must divide evenly by the tensor-parallel world
        size. KV heads are either partitioned (when there are at least as
        many as ranks) or replicated (when there are fewer).
        """
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are always partitioned across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: KV heads are replicated across
            # the tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: KV heads are partitioned
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        fallback_head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = getattr(config, "head_dim", fallback_head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dimension into Q, K, V.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        output, _ = self.o_proj(context)
        return output
class SolarDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, MLP.

    The residual stream is passed in and out explicitly alongside the hidden
    states.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Read layer hyper-parameters from ``config`` and build sub-modules.

        Note that when ``config.rope_scaling`` is set and
        ``config.original_max_position_embeddings`` is truthy, the latter is
        written into the ``rope_scaling`` mapping in place.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_pos = None
        if rope_scaling is not None:
            original_max_pos = getattr(
                config, "original_max_position_embeddings", None)
        if original_max_pos:
            rope_scaling["original_max_position_embeddings"] \
                = config.original_max_position_embeddings
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)

        self.self_attn = SolarAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = SolarMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual; otherwise the input layernorm is called with the residual
        and returns both updated tensors.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
class SolarModel(nn.Module):
    """Decoder stack of the Solar model with pipeline-parallel support.

    Besides the usual layer loop, ``forward`` snapshots the hidden/residual
    streams at the layers listed in ``config.bskcn_1`` / ``config.bskcn_2``
    and blends those snapshots back in at the layers listed in
    ``config.bskcn_3`` / ``config.bskcn_4``, weighted by ``config.bskcn_tv``.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, the decoder layers and the final norm.

        Modules that this pipeline rank does not need are replaced by
        ``PPMissingLayer`` instances.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        # Extra vocabulary rows reserved for LoRA adapters, if any.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        pp_group = get_pp_group()
        # The embedding lives on the first rank, and also on the last rank
        # when the word embeddings are tied.
        needs_embedding = pp_group.is_first_rank or (
            config.tie_word_embeddings and pp_group.is_last_rank)
        if not needs_embedding:
            self.embed_tokens = PPMissingLayer()
        else:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )

        def build_layer(prefix: str) -> SolarDecoderLayer:
            return SolarDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if not pp_group.is_last_rank:
            self.norm = PPMissingLayer()
        else:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns:
            The normed hidden states on the last pipeline rank, otherwise an
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        # Snapshots of the two streams taken at the bskcn_1 / bskcn_2 layers.
        # NOTE(review): snapshots start as None on every call and are local
        # to this rank, so a blend layer hosted on a different pipeline rank
        # than its snapshot layer would see None here - confirm the layer
        # partitioning rules out that case.
        saved_h1 = saved_r1 = None
        saved_h2 = saved_r2 = None
        cfg = self.config
        # Index 0 is used in training mode, index 1 otherwise.
        blend = cfg.bskcn_tv[0] if self.training else cfg.bskcn_tv[1]

        def mix(saved, current):
            return saved * blend + current * (1 - blend)

        for layer_idx in range(self.start_layer, self.end_layer):
            if layer_idx in cfg.bskcn_1:
                saved_h1 = hidden_states.clone()
                saved_r1 = residual.clone()
            if layer_idx in cfg.bskcn_2:
                saved_h2 = hidden_states.clone()
                saved_r2 = residual.clone()
            if layer_idx in cfg.bskcn_3:
                hidden_states = mix(saved_h1, hidden_states)
                residual = mix(saved_r1, residual)
            if layer_idx in cfg.bskcn_4:
                hidden_states = mix(saved_h2, hidden_states)
                residual = mix(saved_r2, residual)
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states, residual = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
                residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
class SolarForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = SolarModel(
config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model",
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
"residual":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Install per-layer KV cache scaling factors read from a file.

    ``kv_cache_scales_loader`` is iterated for ``(layer_idx, scaling_factor)``
    pairs and each factor is stored on ``self_attn.attn._kv_scale`` of the
    corresponding layer.

    Args:
        quantization_param_path: path handed to ``kv_cache_scales_loader``.

    Raises:
        RuntimeError: if a layer's self-attention module has no ``kv_scale``
            attribute.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
    ):
        layer = self.model.layers[layer_idx]
        if isinstance(layer, nn.Identity):
            # An nn.Identity placeholder has no self_attn to configure.
            # Skip it outright so the code below never runs with an unbound
            # or stale attention module from a previous iteration.
            continue
        layer_self_attn = layer.self_attn

        if is_hip():
            # The scaling factor convention we are assuming is
            # quantized_value * scaling_factor ~= true_value
            # which is consistent with the practice of setting
            # scaling_factor = tensor_amax / FPtype_max
            scaling_factor *= 2
        # NOTE(review): the presence check is on `kv_scale` of the attention
        # wrapper while the value is written to `attn._kv_scale` -- confirm
        # this asymmetry is intended.
        if hasattr(layer_self_attn, "kv_scale"):
            layer_self_attn.attn._kv_scale = scaling_factor
        else:
            raise RuntimeError("Self attention has no KV cache scaling "
                               "factor attribute!")
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import itertools
import math
from array import array
from functools import lru_cache
......@@ -21,15 +20,16 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -43,8 +43,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
logger = init_logger(__name__)
class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
......@@ -77,15 +75,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
def dummy_data_for_ultravox(
def dummy_seq_data_for_ultravox(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
audio_count: int,
):
feature_extractor = whisper_feature_extractor(ctx)
audio_count = mm_counts["audio"]
audio_placeholder = array(
VLLM_TOKEN_ID_ARRAY_TYPE,
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
......@@ -96,10 +90,28 @@ def dummy_data_for_ultravox(
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids))
return SequenceData(audio_token_ids + other_token_ids)
def dummy_audio_for_ultravox(
    ctx: InputContext,
    audio_count: int,
):
    """Build placeholder multi-modal audio data.

    Args:
        ctx: input context passed to ``whisper_feature_extractor``.
        audio_count: number of audio entries to produce.

    Returns:
        A dict ``{"audio": [...]}`` holding ``audio_count`` references to the
        same ``(samples, 1)`` tuple, where ``samples`` is an array of
        ``feature_extractor.chunk_length`` zeros.
    """
    feature_extractor = whisper_feature_extractor(ctx)
    # Second tuple element is presumably the sampling rate -- TODO confirm.
    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
    # Build the result once; no throwaway duplicate of the same dict.
    return {"audio": [audio_and_sr] * audio_count}
def dummy_data_for_ultravox(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
):
    """Assemble dummy sequence data together with dummy audio inputs.

    Args:
        ctx: input context forwarded to both helper builders.
        seq_len: sequence length forwarded to
            ``dummy_seq_data_for_ultravox``.
        mm_counts: per-modality item counts; the ``"audio"`` entry is used.

    Returns:
        A ``(seq_data, mm_dict)`` tuple built by
        ``dummy_seq_data_for_ultravox`` and ``dummy_audio_for_ultravox``.

    Raises:
        KeyError: if ``mm_counts`` has no ``"audio"`` entry.
    """
    audio_count = mm_counts["audio"]
    seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)

    # Return the values computed above; the token-id arrays are local to
    # dummy_seq_data_for_ultravox and are not in scope here.
    return (seq_data, mm_dict)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
......@@ -323,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
self.multi_modal_config = multimodal_config
assert self.multi_modal_config
self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None:
self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id)
else:
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
self.secondary_weights.append(
DefaultModelLoader.Source(
model_or_path=config.audio_model_id,
revision=None,
prefix="audio_tower.",
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
if config.text_model_id is not None:
self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id,
revision=None,
prefix="language_model."))
def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
......@@ -453,11 +474,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
projector_weights, llm_weights = itertools.tee(weights, 2)
weights_group = group_weights_with_prefix(weights)
# load audio tower weights
audio_tower_weights = weights_group["audio_tower"]
audio_tower_params_dict = dict(
self.audio_tower.named_parameters(
prefix=self.audio_tower.base_model_prefix))
for name, loaded_weight in audio_tower_weights:
if name in audio_tower_params_dict:
param = audio_tower_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load projector weights
projector_weights = filter_weights(projector_weights,
"multi_modal_projector")
projector_weights = weights_group["multi_modal_projector"]
projector_params_dict = dict(
self.multi_modal_projector.named_parameters())
for name, loaded_weight in projector_weights:
......@@ -467,5 +499,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
import itertools
from collections import UserDict
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload)
......@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
class WeightsGroup(UserDict):
    """
    Wraps grouped weights dictionary for a more informative error message
    when attempting to access a weight component that does not exist.
    """

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        """Return the weights stored under ``key``.

        Raises:
            KeyError: if ``key`` is not a known prefix; the message lists
                the available prefixes and the original ``KeyError`` is
                chained as the cause.
        """
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no weights named with the prefix: {key}. "
                   f"Available prefix: {set(self.keys())}")
            raise KeyError(msg) from exc
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
"""
Helper function to load weights for inner vLLM models.
......@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
yield name, loaded_weight
def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
    """
    Helper function to group weights with prefix

    The prefix of a weight is the part of its name before the first ".".
    The result maps each prefix to ``filter_weights(<copy of weights>,
    prefix)``, wrapped in a ``WeightsGroup``.
    """
    # One copy is drained to discover the prefixes, the other is kept so it
    # can be fanned out per prefix afterwards.
    scan_copy, replay_copy = itertools.tee(weights, 2)
    prefixes = set()
    for weight_name, _ in scan_copy:
        prefixes.add(weight_name.split(".")[0])

    per_prefix_copies = itertools.tee(replay_copy, len(prefixes))
    grouped = {}
    for component, prefix in zip(per_prefix_copies, prefixes):
        grouped[prefix] = filter_weights(component, prefix)
    return WeightsGroup(grouped)
def init_vllm_registered_model(
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
......
......@@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
marlin_tile_size=self.marlin_tile_size)
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)

    Args:
        param: parameter whose ``data`` is permuted in place.
        input_dim: index the input dimension must end up at.
        output_dim: index the output dimension must end up at.
        **kwargs: optional ``packed_dim``; if given, the parameter must
            already be packed along the dimension that lands at that index.

    Returns:
        The same ``param`` object, with ``data`` permuted and the
        ``_input_dim`` / ``_output_dim`` / ``_packed_dim`` attributes updated
        when they exist.
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    # Insert the lower target index first. A later insert at a smaller index
    # would shift the earlier entry one slot to the right, so with
    # output_dim < input_dim the input dimension could land in the wrong
    # place. The stable sort keeps the input-first order on ties.
    placements = sorted(
        [(input_dim, curr_input_dim), (output_dim, curr_output_dim)],
        key=lambda placement: placement[0],
    )
    for target_dim, source_dim in placements:
        perm.insert(target_dim, source_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    """Scale a shard's size and offset by the Marlin tile size.

    Returns:
        ``(shard_size * marlin_tile_size, shard_offset * marlin_tile_size)``.
    """
    scaled_size = shard_size * marlin_tile_size
    scaled_offset = shard_offset * marlin_tile_size
    return scaled_size, scaled_offset
......
import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
......@@ -8,15 +7,10 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
is_pin_memory_available, make_tensor_with_pad)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER = False
@dataclass
......@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
generator=None,
is_prompt=True,
prompt_logprob_indices=[],
sample_indices=[])
sample_indices=[],
)
class SamplingMetadataCache:
"""Used to cache SamplingMetadata objects between scheduler iterations
"""
"""Used to cache SamplingMetadata objects between scheduler iterations"""
def __init__(self):
self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
......@@ -124,12 +118,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
def __init__(
......@@ -165,16 +159,19 @@ class SamplingMetadata:
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device, generators, cache)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
selected_token_indices = async_tensor_h2d(
selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory,
)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory), 2, 2)
t: async_tensor_h2d(
seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory,
)
for t, seq_ids in categorized_sample_indices.items()
}
......@@ -201,8 +198,8 @@ def _prepare_seq_groups(
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
List[int]], int, ]:
"""Prepare sequence groups and indices for sampling.
Args:
......@@ -233,16 +230,13 @@ def _prepare_seq_groups(
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
categorized_sample_indices: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups.
num_prompts = 0
......@@ -264,10 +258,10 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = \
sample_obj.prompt_logprob_indices if cache is not None else []
sample_indices: List[int] = \
sample_obj.sample_indices if cache is not None else []
prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
if cache is not None else [])
sample_indices: List[int] = (sample_obj.sample_indices
if cache is not None else [])
do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt:
......@@ -333,11 +327,8 @@ def _prepare_seq_groups(
if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend(
list(
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
list(range(logit_idx, logit_idx + sample_len)))
logit_idx += sample_len
sample_idx += sample_len
if cache is not None:
sample_obj.sampling_params = sampling_params
......@@ -356,7 +347,8 @@ def _prepare_seq_groups(
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices))
sample_indices=list(sample_indices),
)
seq_groups.append(sample_obj)
......@@ -378,9 +370,6 @@ class SamplingTensors:
presence_penalties: torch.Tensor
frequency_penalties: torch.Tensor
repetition_penalties: torch.Tensor
sampling_seeds: torch.Tensor
sample_indices: torch.Tensor
extra_seeds: Optional[torch.Tensor]
prompt_tokens: torch.Tensor
output_tokens: torch.Tensor
......@@ -391,15 +380,7 @@ class SamplingTensors:
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
*,
extra_seeds_to_generate: int = 0,
extra_entropy: Optional[Tuple[int, ...]] = None
) -> Tuple["SamplingTensors", bool, bool, bool]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens: List[array] = []
output_tokens: List[array] = []
top_ks: List[int] = []
......@@ -409,19 +390,10 @@ class SamplingTensors:
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
sampling_seeds: List[int] = []
sample_indices: List[int] = []
do_penalties = False
do_top_p_top_k = False
do_min_p = False
if _USE_TRITON_SAMPLER:
prompt_best_of: List[int] = []
# We need one base seed per Triton slice.
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
assert sampling_metadata.seq_groups is not None
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
......@@ -452,7 +424,7 @@ class SamplingTensors:
do_penalties = True
is_prompt = seq_group.is_prompt
if (is_prompt and sampling_params.prompt_logprobs is not None):
if is_prompt and sampling_params.prompt_logprobs is not None:
# For tokens in the prompt that we only need to get
# their logprobs
query_len = seq_group.query_len
......@@ -477,28 +449,6 @@ class SamplingTensors:
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
if _USE_TRITON_SAMPLER:
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
query_len = seq_group.query_len
assert query_len is not None
seed = sampling_params.seed
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
seq_data.get_len(),
*extra_entropy,
seq_id,
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.extend(seq_group.sample_indices)
if do_penalties:
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
......@@ -518,23 +468,37 @@ class SamplingTensors:
output_tokens.append(seq_data.output_token_ids_array)
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
frequency_penalties, repetition_penalties, sampling_seeds,
sample_indices, prompt_tokens, output_tokens, vocab_size,
extra_seeds_to_generate, device, dtype)
temperatures,
top_ps,
top_ks,
min_ps,
presence_penalties,
frequency_penalties,
repetition_penalties,
prompt_tokens,
output_tokens,
vocab_size,
device,
dtype,
)
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
@classmethod
def from_lists(cls, temperatures: List[float], top_ps: List[float],
top_ks: List[int], min_ps: List[float],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
sampling_seeds: List[int], sample_indices: List[int],
prompt_tokens: List[array], output_tokens: List[array],
vocab_size: int, extra_seeds_to_generate: int,
device: torch.device,
dtype: torch.dtype) -> "SamplingTensors":
def from_lists(
cls,
temperatures: List[float],
top_ps: List[float],
top_ks: List[int],
min_ps: List[float],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
prompt_tokens: List[array],
output_tokens: List[array],
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
) -> "SamplingTensors":
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
......@@ -603,34 +567,9 @@ class SamplingTensors:
dtype=torch.int,
pin_memory=pin_memory,
)
sample_indices_t = torch.tensor(
sample_indices,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t = torch.tensor(
sampling_seeds,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).t().contiguous()
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
sampling_seeds_gpu = sampling_seeds_t.to(device=device,
non_blocking=True)
extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
if not extra_seeds_gpu.numel():
extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
......@@ -644,38 +583,4 @@ class SamplingTensors:
non_blocking=True),
prompt_tokens=prompt_t.to(device=device, non_blocking=True),
output_tokens=output_t.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device,
non_blocking=True),
extra_seeds=extra_seeds_gpu,
)
@staticmethod
def _get_sequence_seeds(
    seed: int,
    *extra_entropy: int,
    seeds_to_generate: int,
    is_greedy: bool,
):
    """Get `seeds_to_generate` child seeds from `seed` and extra entropy.

    Greedy requests get all-zero seeds. Otherwise each child seed is drawn
    over the full ``torch.long`` range: from the module-level ``random``
    generator when ``seed`` is None, else from a ``random.Random`` seeded
    with the string form of ``(seed, *extra_entropy)``.
    """
    if is_greedy:
        # For the kernel, seed == 0 means greedy decoding.
        return [0] * seeds_to_generate

    if seed is None:
        draw = random.randint
    else:
        draw = random.Random(str((seed, ) + extra_entropy)).randint

    long_range = torch.iinfo(torch.long)
    lo, hi = long_range.min, long_range.max
    # If the user/random sets seed = 0 but request should
    # have sampling, we need to change it to something
    # else. We use a constant in that case.
    # This way we don't need to create and load a bool
    # matrix in the sampling kernel, which reduces CPU
    # overhead and latency.
    child_seeds = []
    for _ in range(seeds_to_generate):
        candidate = draw(lo, hi)
        child_seeds.append(candidate if candidate else _SEED_0_REPLACEMENT)
    return child_seeds
"""Utils for model executor."""
import random
from typing import Any, Dict, Optional
import numpy as np
import torch
from vllm.utils import seed_everything
def set_random_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    CUDA generators are seeded as well when CUDA is available, after which
    ``seed_everything`` is called with the same seed.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # NOTE(review): seed_everything presumably covers the generators seeded
    # above -- confirm before dropping the explicit calls.
    seed_everything(seed)
def set_weight_attrs(
......
......@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import JSONTree, is_list_of, json_map_leaves
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
json_map_leaves)
logger = init_logger(__name__)
......@@ -53,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
......@@ -256,11 +263,20 @@ class MultiModalPlugin(ABC):
model_cls, _ = get_model_architecture(model_config)
mapper = self._input_mappers.get(model_cls)
# Only get processor kwargs at mapping time if we are not using the
# input mapper; no overrides are used on the default here because they
# should be passed to the huggingface resource at initialization time.
if mapper is not None and mapper != self._default_input_mapper:
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
mapper, overrides=model_config.mm_processor_kwargs)
else:
mm_processor_kwargs = {}
if mapper is None:
raise KeyError(f"No input mapper in {self} is registered for "
f"model class {model_cls.__name__}.")
return mapper(InputContext(model_config), data)
return mapper(InputContext(model_config), data, **mm_processor_kwargs)
@abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
......@@ -333,7 +349,10 @@ class MultiModalPlugin(ABC):
f"for model class {model_cls.__name__} in {self}.")
if callable(max_mm_tokens):
max_mm_tokens = max_mm_tokens(InputContext(model_config))
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
max_mm_tokens, overrides=model_config.mm_processor_kwargs)
max_mm_tokens = max_mm_tokens(InputContext(model_config),
**mm_processor_kwargs)
self._validate_max_multimodal_tokens(max_mm_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment