Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2ForCausalLM,
InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors
from .utils import make_layers
class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer with a second feed-forward branch for visual tokens.

    The layer owns one attention module and two MLPs built from identical
    arguments: ``feed_forward`` and ``feed_forward_ve``. When a visual token
    mask is supplied and selects at least one position, the selected
    positions are processed by ``feed_forward_ve`` and the rest by
    ``feed_forward``; otherwise everything goes through ``feed_forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention module, both MLPs and the two RMSNorm layers.

        Args:
            config: Model configuration. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are optional and default to
                ``10000``, ``None`` and ``8192`` respectively.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention module and both MLPs.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
        )

        # Both MLP branches are constructed from exactly the same arguments.
        mlp_kwargs = dict(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(**mlp_kwargs)
        self.feed_forward_ve = InternLM2MLP(**mlp_kwargs)

        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention followed by the (possibly token-routed) MLP.

        Args:
            positions: Passed through to the attention module.
            hidden_states: Input activations for this layer.
            kv_cache: Passed through to the attention module.
            attn_metadata: Passed through to the attention module.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer of the stack.
            visual_token_mask: Optional mask marking visual tokens.
                NOTE(review): assumed to have shape ``(num_tokens, 1)`` so
                that ``repeat(1, hidden_size)`` lines up with
                ``hidden_states`` - confirm against the caller.

        Returns:
            The ``(hidden_states, residual)`` pair for the next layer.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the un-normalised input becomes the residual.
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)

        has_visual_tokens = (visual_token_mask is not None
                             and visual_token_mask.any())
        if not has_visual_tokens:
            return self.feed_forward(hidden_states), residual

        # Widen the per-token mask to a per-element mask.
        visual_elements = visual_token_mask.repeat(
            1, self.hidden_size).bool()
        text_elements = ~visual_elements

        # Visual positions are rewritten in place with the visual MLP output.
        visual_rows = hidden_states[visual_elements].reshape(
            -1, self.hidden_size)
        hidden_states[visual_elements] = self.feed_forward_ve(
            visual_rows).flatten()

        if text_elements.any():
            text_rows = hidden_states[text_elements].reshape(
                -1, self.hidden_size)
            hidden_states[text_elements] = self.feed_forward(
                text_rows).flatten()

        return hidden_states, residual
class InternLM2VEModel(InternLM2Model):
    """InternLM2 backbone whose decoder layers are ``InternLM2VEDecoderLayer``.

    The parent constructor is run first; ``start_layer``, ``end_layer`` and
    ``layers`` are then replaced with the values returned by ``make_layers``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Initialise the parent model and swap in the VE decoder layers.

        Args:
            config: Model configuration.
            cache_config: Forwarded to the parent and to every layer.
            quant_config: Forwarded to the parent and to every layer.
            prefix: Name prefix; ``"<prefix>.layers"`` is given to
                ``make_layers``.
        """
        super().__init__(config, cache_config, quant_config)

        def build_layer(prefix):
            # The per-layer prefix supplied by make_layers is not used here.
            return InternLM2VEDecoderLayer(config, cache_config, quant_config)

        layer_range_and_modules = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.start_layer, self.end_layer, self.layers = (
            layer_range_and_modules)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)`` for this rank.

        Args:
            input_ids: Token ids; embedded with ``tok_embeddings`` on the
                first pipeline rank when ``inputs_embeds`` is ``None``.
            positions: Forwarded to every layer.
            kv_caches: One cache per local layer, indexed from zero.
            attn_metadata: Forwarded to every layer.
            intermediate_tensors: Required on non-first ranks; must provide
                ``"hidden_states"`` and ``"residual"``.
            inputs_embeds: Pre-computed embeddings that take precedence over
                ``input_ids`` on the first rank.
            visual_token_mask: Forwarded to every layer.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` when this is not the last pipeline rank; otherwise
            the output of the final norm.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.tok_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # kv_caches is indexed relative to the first layer held by this rank.
        local_layers = range(self.start_layer, self.end_layer)
        for cache_slot, layer_idx in enumerate(local_layers):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[cache_slot],
                attn_metadata,
                residual,
                visual_token_mask=visual_token_mask,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
    """``InternLM2ForCausalLM`` whose ``model`` is an ``InternLM2VEModel``."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Initialise the parent, then replace ``self.model``.

        Args:
            config: Model configuration.
            cache_config: Forwarded to the parent and to the backbone.
            quant_config: Forwarded to the parent and to the backbone.
        """
        super().__init__(config, cache_config, quant_config)
        # Overrides whatever backbone the parent constructor installed.
        self.model = InternLM2VEModel(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
......@@ -21,7 +21,8 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
......@@ -427,13 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
vision_feature_layer = self.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
self.vision_model = self._init_vision_model(config, num_hidden_layers)
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
self.vision_model = self._init_vision_model(config, self.is_mono)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
......@@ -451,10 +448,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return Sampler()
def _init_vision_model(self, config: PretrainedConfig,
num_hidden_layers: int):
return InternVisionModel(config.vision_config,
num_hidden_layers_override=num_hidden_layers)
def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
    """Create the vision encoder for this model.

    Args:
        config: Top-level configuration; ``config.vision_config`` is passed
            to the vision model.
        is_mono: When true an ``InternVisionPatchModel`` is returned.

    Returns:
        ``InternVisionPatchModel`` for the mono variant, otherwise an
        ``InternVisionModel`` limited to the layers needed to reach
        ``self.select_layer``.
    """
    if is_mono:
        return InternVisionPatchModel(config.vision_config)

    # A negative select_layer counts back from the last hidden layer.
    select_layer = self.select_layer
    layers_to_keep = select_layer + 1
    if select_layer < 0:
        layers_to_keep += config.vision_config.num_hidden_layers

    return InternVisionModel(
        config.vision_config,
        num_hidden_layers_override=layers_to_keep,
    )
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
......@@ -562,6 +568,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return image_embeds
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.is_mono:
visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
else:
visual_token_mask = None
return visual_token_mask
def forward(
self,
input_ids: torch.Tensor,
......@@ -574,6 +588,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
visual_token_mask = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
......@@ -583,16 +598,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
visual_token_mask = self._get_visual_token_mask(input_ids)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
visual_token_mask = None
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
def compute_logits(
......
......@@ -281,13 +281,7 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
......
......@@ -5,7 +5,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
......@@ -22,6 +23,10 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
dummy_seq_data_for_pixtral_hf,
get_max_pixtral_hf_image_tokens,
input_processor_for_pixtral_hf)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
......@@ -31,8 +36,13 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
class LlavaImageEmbeddingInputs(TypedDict):
......@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
elif isinstance(vision_config, PixtralVisionConfig):
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, PixtralVisionConfig):
# We ignore image_feature_size_override since we have non-uniform
# image sizes for Pixtral
return input_processor_for_pixtral_hf(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, PixtralVisionConfig):
# TODO: allow layer override?
return PixtralHFVisionModel(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......@@ -210,6 +245,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.config = config
self.multimodal_config = multimodal_config
# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if (config.text_config.architectures is None
and config.text_config.model_type == "mistral"):
config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
and config.vision_config.hidden_act == "gelu"):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaMultiModalProjector(
......@@ -243,9 +287,38 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return data
def _validate_image_sizes(self, images: List[torch.Tensor],
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(sizes, list):
sizes = [sizes]
total_images = sum(size.numel() // 2 for size in sizes)
if total_images != len(images):
raise ValueError("Mismatch in number of images. "
f"Expected {total_images}, got {len(images)}")
img_idx = 0
for size in sizes:
# Flatten the size tensor to a list of (height, width) pairs
size = size.view(-1, 2).tolist()
for expected_h, expected_w in size:
if img_idx >= len(images):
raise ValueError("Ran out of images before sizes. "
f"{img_idx} >= {len(images)}")
img = images[img_idx]
if img.shape[-2:] != (expected_h, expected_w):
raise ValueError(
"Image size mismatch. Expected "
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
if img.shape[-3] != 3:
raise ValueError("Image channel mismatch. Expected 3, "
f"got {img.shape[-3]}")
img_idx += 1
return images
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
......@@ -256,6 +329,34 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Case for models like PixtralHF that have dynamic image sizes
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
def flatten_to_3d_tensors(item):
    """Flatten nested lists/tensors into a flat list of 3-D tensors.

    A tensor with three or more dimensions is viewed as
    ``(-1, *last_three_dims)`` and split along the first axis; a list is
    flattened recursively, preserving order.

    Raises:
        ValueError: For tensors with fewer than three dimensions, or for
            items that are neither tensors nor lists.
    """
    if isinstance(item, torch.Tensor):
        if item.dim() < 3:
            raise ValueError(
                f"Unexpected tensor dimension: {item.dim()}")
        # Collapse every leading dimension into one, then split along it.
        return list(item.view(-1, *item.shape[-3:]))

    if isinstance(item, list):
        flattened = []
        for subitem in item:
            flattened.extend(flatten_to_3d_tensors(subitem))
        return flattened

    raise ValueError(f"Unexpected type: {type(item)}")
# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_sizes(images, image_sizes),
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
......@@ -286,7 +387,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
......
......@@ -13,11 +13,13 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
......@@ -28,8 +30,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
......@@ -312,6 +314,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
......@@ -605,14 +611,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
inputs_embeds = embed_multimodal(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
lambda _: self._process_image_input(image_input),
)
input_ids = None
else:
inputs_embeds = None
......@@ -641,6 +645,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool ``hidden_states`` by delegating to ``self._pooler``.

    Args:
        hidden_states: Hidden states to pool.
        pooling_metadata: Metadata handed to the pooler unchanged.

    Returns:
        Whatever ``self._pooler`` returns for these arguments.
    """
    pooled_output = self._pooler(hidden_states, pooling_metadata)
    return pooled_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
......@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
......@@ -59,7 +59,7 @@ class MambaMixer(nn.Module):
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
......@@ -109,6 +109,13 @@ class MambaMixer(nn.Module):
input_is_parallel=True,
)
self.activation = config.hidden_act
if self.is_falcon_mamba:
self.dt_layernorm = RMSNorm(self.time_step_rank,
eps=config.mixer_rms_eps)
self.b_layernorm = RMSNorm(self.ssm_state_size,
eps=config.mixer_rms_eps)
self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.mixer_rms_eps)
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
......@@ -158,8 +165,12 @@ class MambaMixer(nn.Module):
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
# Note that Jamba and FalconMamba normalize B, C, and time_step here
# but Mamba doesn't.
if self.is_falcon_mamba:
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
......@@ -213,11 +224,9 @@ class MambaDecoderLayer(nn.Module):
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.mixer = MambaMixer(config, layer_idx)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
......@@ -319,8 +328,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = self.backbone.embeddings
if config.tie_word_embeddings:
self.lm_head = self.backbone.embeddings
else:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
......@@ -398,7 +417,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......
......@@ -102,8 +102,9 @@ class PhiAttention(nn.Module):
# pylint: disable=C0301
# Refer to:
# https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
rope_theta = 10000
max_position_embeddings = getattr(config, "n_positions", 2048)
rope_theta = getattr(config, "rope_theta", 10000.0)
max_position_embeddings = getattr(config, "max_position_embeddings",
2048)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
......
......@@ -467,8 +467,6 @@ def input_processor_for_phi3v(ctx: InputContext,
prompt_token_ids = inputs["prompt_token_ids"].copy()
print("prompt_token_ids (old)", prompt_token_ids)
# masked placeholder with image token id
for idx in image_idx:
candidates = _get_image_placeholder_token_id_candidates(model_config,
......
......@@ -3,18 +3,25 @@ from functools import cached_property
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from transformers import PixtralVisionConfig, PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -25,6 +32,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model
......@@ -576,3 +585,380 @@ class VisionLanguageAdapter(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))
#### HF Transformers version of Pixtral ####
# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are placed
# instead of the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.
def get_pixtral_hf_patch_grid_length(*, image_size: int,
                                     patch_size: int) -> int:
    """Return the number of whole patches along one side of the image.

    Args:
        image_size: Side length of the image.
        patch_size: Side length of one patch.

    Returns:
        ``image_size`` floor-divided by ``patch_size``.
    """
    # Since interpolation is applied, the image size need not be divisible
    # by the patch size; any remainder is simply discarded.
    whole_patches, _remainder = divmod(image_size, patch_size)
    return whole_patches
def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the patch count of a square image: grid length squared.

    Args:
        image_size: Side length of the image.
        patch_size: Side length of one patch.
    """
    patches_per_side = get_pixtral_hf_patch_grid_length(
        image_size=image_size, patch_size=patch_size)
    return patches_per_side**2
def get_max_pixtral_hf_image_feature_size(
        hf_config: PixtralVisionConfig) -> int:
    """Return the patch count for an image of the configured size.

    Args:
        hf_config: Vision config providing ``image_size`` and ``patch_size``.
    """
    configured_size = hf_config.image_size
    configured_patch = hf_config.patch_size
    return get_pixtral_hf_num_patches(image_size=configured_size,
                                      patch_size=configured_patch)
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
    """Return the maximum image token count for ``hf_config``.

    This is the same value as ``get_max_pixtral_hf_image_feature_size``.
    """
    max_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
    return max_feature_size
def dummy_seq_data_for_pixtral_hf(
    hf_config: PixtralVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy sequence data containing image tokens plus filler.

    Args:
        hf_config: Vision config used to derive the default feature size.
        seq_len: Total number of tokens requested.
        num_images: Number of images to account for.
        image_token_id: Token id used for the image positions.
        image_feature_size_override: Per-image feature size to use instead
            of the maximum derived from ``hf_config``.

    Returns:
        The result of ``SequenceData.from_prompt_token_counts`` called with
        ``(image_token_id, image tokens)`` followed by ``(0, remainder)``.
    """
    image_feature_size = (
        get_max_pixtral_hf_image_feature_size(hf_config)
        if image_feature_size_override is None else
        image_feature_size_override)

    total_image_tokens = image_feature_size * num_images
    filler_tokens = seq_len - total_image_tokens
    return SequenceData.from_prompt_token_counts(
        (image_token_id, total_image_tokens),
        (0, filler_tokens),
    )
def dummy_image_for_pixtral_hf(
    hf_config: PixtralVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create blank RGB dummy image data.

    Args:
        hf_config: Vision config; ``image_size`` is the default for both
            width and height.
        num_images: Number of images requested.
        image_width_override: Width to use instead of the default.
        image_height_override: Height to use instead of the default.

    Returns:
        ``{"image": image}`` when ``num_images`` is 1, otherwise
        ``{"image": [image] * num_images}`` (the list repeats one image
        object).
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    image = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": image}
    return {"image": [image] * num_images}
def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
                                      image_width: int,
                                      image_height: int) -> Tuple[int, int]:
    """Return ``(num_width_tokens, num_height_tokens)`` for one image.

    An image exceeding ``hf_config.image_size`` in either dimension is
    scaled down by the larger of the two ratios (rounding up) before the
    token grid is computed with ``_num_image_tokens``.

    Args:
        hf_config: Vision config providing ``image_size`` and ``patch_size``.
        image_width: Width of the input image.
        image_height: Height of the input image.
    """
    # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
    # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
    max_side = hf_config.image_size
    patch_side = hf_config.patch_size

    scaled_width, scaled_height = image_width, image_height
    ratio = max(image_width / max_side, image_height / max_side)
    if ratio > 1:
        scaled_width = int(numpy.ceil(image_width / ratio))
        scaled_height = int(numpy.ceil(image_height / ratio))

    # _num_image_tokens takes and returns (height, width) ordering.
    num_height_tokens, num_width_tokens = _num_image_tokens(
        (scaled_height, scaled_width), (patch_side, patch_side))

    return num_width_tokens, num_height_tokens
def input_processor_for_pixtral_hf(
    model_config: ModelConfig,
    hf_config: PixtralVisionConfig,
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
) -> DecoderOnlyInputs:
    """Expand each image placeholder into its full grid of image tokens.

    Every image token in the prompt (text and token ids) is replaced by
    ``num_height_tokens`` rows, each made of ``num_width_tokens`` image
    tokens followed by a break token; the final break token is replaced by
    the end token.

    The caller's ``inputs["prompt_token_ids"]`` list is not modified; a new
    list is built and returned.

    Args:
        model_config: Model config; ``model_config.model`` selects the
            processor.
        hf_config: Vision config used to compute per-image token grids.
        inputs: Decoder-only inputs, optionally with image data under
            ``multi_modal_data["image"]``.
        image_token_id: Accepted for interface compatibility. The id that
            is actually matched and emitted is resolved from the
            processor's image token, as before.
        image_feature_size_override: Must be ``None``; not supported.

    Returns:
        ``inputs`` unchanged when there is no image data, otherwise new
        token inputs with the expanded prompt and token ids.

    Raises:
        TypeError: If the image data is neither a PIL image nor a list of
            PIL images.
        AssertionError: If an override is given, or the number of image
            placeholders differs from the number of images.
    """
    assert image_feature_size_override is None, (
        "image_feature_size_override is not supported for Pixtral")

    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    processor = cached_get_processor(model_config.model)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_data = [image_data]
    elif not is_list_of(image_data, Image.Image):
        raise TypeError(f"Invalid image type: {type(image_data)}")

    image_token = processor.image_token
    image_break_token = processor.image_break_token
    image_end_token = processor.image_end_token

    def expand_placeholder(item, row_break, end, num_width_tokens,
                           num_height_tokens):
        # One row is `num_width_tokens` items plus a break; the break that
        # closes the last row is swapped for the end marker.
        layout = ([item] * num_width_tokens + [row_break]) * num_height_tokens
        layout[-1] = end
        return layout

    # (num_width_tokens, num_height_tokens) per image, computed once and
    # shared by the text prompt and the token-id expansion.
    token_grids = []
    for image in image_data:
        w, h = image.size
        token_grids.append(
            get_pixtral_hf_image_feature_size(hf_config,
                                              image_width=w,
                                              image_height=h))

    # Update new_prompt if present
    new_prompt = inputs.get("prompt")
    if new_prompt:
        parts = new_prompt.split(image_token)
        assert len(parts) - 1 == len(image_data)
        new_parts = [parts[0]]  # Start with the part before any image tokens
        for (num_width_tokens,
             num_height_tokens), next_part in zip(token_grids, parts[1:]):
            new_parts.append("".join(
                expand_placeholder(image_token, image_break_token,
                                   image_end_token, num_width_tokens,
                                   num_height_tokens)))
            new_parts.append(next_part)
        new_prompt = "".join(new_parts)

    # Resolve ids from the processor's tokenizer. A separate local name is
    # used so the `image_token_id` parameter is not silently rebound.
    convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
    resolved_image_token_id = convert_tokens_to_ids(image_token)
    image_break_id = convert_tokens_to_ids(image_break_token)
    image_end_id = convert_tokens_to_ids(image_end_token)

    original_token_ids = inputs["prompt_token_ids"]
    num_placeholders = sum(1 for token_id in original_token_ids
                           if token_id == resolved_image_token_id)
    assert num_placeholders == len(image_data)

    # Build a fresh list in one forward pass instead of editing the
    # caller's list in place.
    new_token_ids = []
    remaining_grids = iter(token_grids)
    for token_id in original_token_ids:
        if token_id != resolved_image_token_id:
            new_token_ids.append(token_id)
            continue
        num_width_tokens, num_height_tokens = next(remaining_grids)
        new_token_ids.extend(
            expand_placeholder(resolved_image_token_id, image_break_id,
                               image_end_id, num_width_tokens,
                               num_height_tokens))

    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data)
class PixtralHFMLP(nn.Module):
    """Gated MLP: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free ``nn.Linear`` layers.
    """

    def __init__(self, config: PixtralVisionConfig):
        """Create the projections and the activation from ``config``.

        Args:
            config: Vision config providing ``hidden_size``,
                ``intermediate_size`` (must not be ``None``) and
                ``hidden_act``.
        """
        super().__init__()

        assert config.intermediate_size is not None
        model_dim = config.hidden_size
        inner_dim = config.intermediate_size

        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``."""
        gate = self.act(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
class PixtralHFAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Queries and keys are rotated with ``apply_rotary_pos_emb`` and attention
    is computed with ``memory_efficient_attention``. All projections are
    bias-free ``nn.Linear`` layers.
    """

    def __init__(self, config: PixtralVisionConfig):
        """Create the q/k/v/o projections.

        Args:
            config: Vision config; ``hidden_size`` must be divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(config.hidden_size,
                                config.hidden_size,
                                bias=False)
        self.k_proj = nn.Linear(config.hidden_size,
                                config.hidden_size,
                                bias=False)
        self.v_proj = nn.Linear(config.hidden_size,
                                config.hidden_size,
                                bias=False)
        self.o_proj = nn.Linear(config.hidden_size,
                                config.hidden_size,
                                bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: BlockDiagonalMask,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Rank-3 tensor unpacked as
                ``(batch, patches, hidden)``.
            attention_mask: Passed to ``memory_efficient_attention`` as
                ``attn_bias``.
            position_embeddings: ``(cos, sin)`` pair for the rotary
                embedding.

        Returns:
            A single tensor: the output projection of the attention result,
            reshaped to ``(batch, patches, n_heads * head_dim)`` before
            projection. No attention weights are returned.
        """
        batch, patches, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        # Transpose q and k back for attention
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)

        return self.o_proj(out)
class PixtralHFTransformerBlock(nn.Module):
    """One encoder layer: normalised attention and normalised MLP, each
    added back onto its input as a residual."""

    def __init__(self, config: PixtralVisionConfig):
        super().__init__()
        width = config.hidden_size
        self.attention_norm = RMSNorm(width, eps=1e-5)
        self.attention = PixtralHFAttention(config)
        self.feed_forward = PixtralHFMLP(config)
        self.ffn_norm = RMSNorm(width, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: BlockDiagonalMask,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``h + feed_forward(ffn_norm(h))`` where
        ``h = hidden_states + attention(attention_norm(hidden_states))``."""
        # Attention sub-layer with residual connection.
        normed_input = self.attention_norm(hidden_states)
        attn_out = self.attention.forward(
            normed_input,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-layer with residual connection.
        mlp_out = self.feed_forward.forward(self.ffn_norm(after_attn))
        return after_attn + mlp_out
class PixtralHFTransformer(nn.Module):
    """Stack of ``config.num_hidden_layers`` transformer blocks applied
    one after another."""

    def __init__(self, config: PixtralVisionConfig):
        super().__init__()
        blocks = [
            PixtralHFTransformerBlock(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = torch.nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: BlockDiagonalMask,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``x`` through every layer in order; the same mask and
        position embeddings are passed to each layer."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, attention_mask, position_embeddings)
        return hidden
# Vision encoder for HF-format Pixtral: a patch convolution, an RMSNorm,
# rotary patch position embeddings and a PixtralHFTransformer stack.
class PixtralHFVisionModel(nn.Module):
def __init__(self, config: PixtralVisionConfig):
super().__init__()
self.config = config
# Patch embedding: kernel and stride are both config.patch_size, so the
# convolution windows do not overlap.
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
bias=False,
)
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(config)
# dtype/device are read once here from the first parameter; they are plain
# attributes and are not refreshed if the module is later moved or cast.
self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding(
config, self.device)
def forward(
self,
pixel_values: List[torch.Tensor],
) -> torch.Tensor:
# NOTE(review): each image gets a leading axis via unsqueeze(0) and the
# sequences are concatenated on dim=1, so the returned tensor looks like
# (1, N_toks, D) rather than the (N_toks, D) stated below -- confirm.
"""
Args:
pixel_values: Each image to be processed will be a separate tensor
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
Returns:
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(img.unsqueeze(0).to(self.dtype))
for img in pixel_values
]
# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
# max_width is the number of patches along one side of a
# config.image_size image (integer division by patch_size).
position_ids = position_ids_in_meshgrid(
patch_embeds_list,
max_width=self.config.image_size // self.config.patch_size).to(
self.device)
position_embedding = self.patch_positional_embedding(
patch_embeds, position_ids)
# One mask segment per image; its length is the product of the last two
# dims of that image's conv output, i.e. its patch count.
attention_mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
out = self.transformer(patch_embeds, attention_mask,
position_embedding)
return out
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
# Copies every (name, tensor) pair from `weights` into the parameter of the
# same name. A name missing from named_parameters() raises KeyError.
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Empty mapping: the inner loop body never runs, so the for-else branch
# below handles every weight.
stacked_params_mapping = []
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# No stacked mapping matched: load by exact name, using the parameter's
# own weight_loader if it has one, else default_weight_loader.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -482,6 +482,28 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
config: Qwen2Config,
......
......@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from functools import partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, Type, TypedDict, Union)
......@@ -63,7 +63,7 @@ from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import get_processor
from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend,
......@@ -78,7 +78,7 @@ logger = init_logger(__name__)
class Qwen2VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape:
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
......@@ -102,14 +102,14 @@ Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
......@@ -544,8 +544,6 @@ class Qwen2VisionTransformer(nn.Module):
# === Vision input helpers === #
cached_get_processor = lru_cache(get_processor)
def mm_input_mapper_for_qwen2_vl(
ctx: InputContext,
......
......@@ -47,12 +47,14 @@ _TEXT_GENERATION_MODELS = {
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
......@@ -87,10 +89,12 @@ _TEXT_GENERATION_MODELS = {
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
......
import itertools
from dataclasses import dataclass, field
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Protocol, Tuple, Union, overload)
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Protocol, Tuple, Union, overload)
import torch
import torch.nn as nn
......@@ -21,7 +21,7 @@ from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_cpu, is_pin_memory_available
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
......@@ -294,10 +294,11 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int) -> torch.Tensor:
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
......@@ -306,8 +307,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
Note:
This updates ``inputs_embeds`` in place.
"""
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum().item()
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings)
......@@ -317,10 +317,70 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened
inputs_embeds[is_multimodal] = flattened
return inputs_embeds
def embed_multimodal(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
                                                          List[torch.Tensor]]],
) -> torch.Tensor:
    """
    Embed token IDs and multimodal inputs and combine their embeddings.

    ``multimodal_token_id`` is used to determine whether a token ID should
    be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.

    Compared to ``merge_multimodal_embeddings``, this avoids running
    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
    which causes issues when the placeholder token ID exceeds the
    vocabulary size of the language model.
    """
    multimodal_mask = input_ids == multimodal_token_id
    text_mask = ~multimodal_mask

    # Each embedder only ever sees the token IDs that belong to it.
    text_embeds = get_text_embeds(input_ids[text_mask])
    multimodal_embeds = get_multimodal_embeds(input_ids[multimodal_mask])

    # Uninitialised buffer with one row per input token, matching the text
    # embeddings' width, dtype and device.
    num_tokens = input_ids.shape[0]
    embed_dim = text_embeds.shape[1]
    merged_embeds = text_embeds.new_empty((num_tokens, embed_dim))
    merged_embeds[text_mask] = text_embeds

    # The remaining (multimodal) rows are filled in by the shared helper.
    return _merge_multimodal_embeddings(merged_embeds, multimodal_mask,
                                        multimodal_embeds)
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int,
) -> torch.Tensor:
    """
    Overwrite the rows of ``inputs_embeds`` whose token in ``input_ids``
    equals ``placeholder_token_id`` with ``multimodal_embeddings``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    placeholder_mask = input_ids == placeholder_token_id
    return _merge_multimodal_embeddings(inputs_embeds, placeholder_mask,
                                        multimodal_embeddings)
class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module:
......@@ -474,7 +534,7 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
class LLMWrapper(nn.Module):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""
......@@ -515,7 +575,7 @@ def get_vit_attn_backend() -> _Backend:
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.")
selected_backend = _Backend.XFORMERS
elif is_cpu():
elif current_platform.is_cpu():
selected_backend = _Backend.TORCH_SDPA
else:
selected_backend = _Backend.XFORMERS
......
import time
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
SequenceGroup, SequenceGroupBase, SequenceStatus)
@dataclass
......@@ -93,7 +92,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
prompt: Optional[PromptType],
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
......@@ -115,14 +114,28 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup,
use_cache: bool) -> Optional["RequestOutput"]:
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
seq_id_to_seq_group: Dict[str, SequenceGroupBase]
) -> Optional["RequestOutput"]:
finished = seq_group.is_finished()
if seq_group.request_id in seq_id_to_seq_group:
group: SequenceGroupBase = seq_id_to_seq_group[
seq_group.request_id]
if finished:
group.finish_seq(seq_group)
assembled_seq_group = group.maybe_assemble_group(seq_group)
if assembled_seq_group is None:
return None
return cls.from_seq_group(assembled_seq_group, use_cache,
seq_id_to_seq_group)
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
......@@ -137,15 +150,7 @@ class RequestOutput:
outputs=[],
finished=False)
seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = sampling_params._real_n or sampling_params.n
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
top_n_seqs = seq_group.get_seqs()
# Create the outputs.
# NOTE: We need omit logprobs here explicitly because the sequence
......@@ -209,7 +214,7 @@ class RequestOutput:
else:
output = CompletionOutput(
seqs.index(seq), output_text, [output_token_ids]
top_n_seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
......@@ -310,10 +315,13 @@ class EmbeddingRequestOutput:
class RequestOutputFactory:
@staticmethod
def create(seq_group: SequenceGroup, use_cache: bool = False):
def create(seq_group: SequenceGroup,
seq_id_to_seq_group: Dict[str, SequenceGroupBase],
use_cache: bool = False):
# Determine the type based on a condition, for example:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache)
return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group)
......@@ -61,6 +61,13 @@ try:
except Exception:
pass
is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
......@@ -78,6 +85,9 @@ elif is_xpu:
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
else:
current_platform = UnspecifiedPlatform()
......
......@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
TPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
UNSPECIFIED = enum.auto()
......@@ -48,6 +49,9 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
......
from .interface import Platform, PlatformEnum
class NeuronPlatform(Platform):
    """Platform implementation tagged with ``PlatformEnum.NEURON``."""

    _enum = PlatformEnum.NEURON

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the fixed name ``"neuron"``; ``device_id`` is not used."""
        return "neuron"
......@@ -49,14 +49,17 @@ class GuidedDecodingParams:
@staticmethod
def from_optional(
json: Optional[Union[Dict, BaseModel, str]],
json: Optional[Union[Dict, BaseModel, str]] = None,
regex: Optional[str] = None,
choice: Optional[List[str]] = None,
grammar: Optional[str] = None,
json_object: Optional[bool] = None,
backend: Optional[str] = None,
whitespace_pattern: Optional[str] = None,
) -> "GuidedDecodingParams":
) -> Optional["GuidedDecodingParams"]:
if all(arg is None
for arg in (json, regex, choice, grammar, json_object)):
return None
# Extract json schemas from pydantic models
if isinstance(json, (BaseModel, type(BaseModel))):
json = json.model_json_schema()
......
from ._core_ext import NanRepr, ScalarType
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """How (and whether) a floating point type encodes NaN.

    Mirrors the enum in `core/scalar_type.hpp`.
    """
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp.  These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, use `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represented in this scalar type, a NanRepr member.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        """Return the raw 64-bit pattern of an IEEE double holding the
        largest finite value of this floating point type."""
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            # the all-ones mantissa is reserved for NaN in this encoding
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            # the all-ones exponent is usable for finite values here
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        """Largest finite value of this floating point type as a float."""
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        """Largest stored value, before the bias is subtracted."""
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        """Smallest stored value, before the bias is subtracted."""
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            # the minimum is the maximum with the sign bit set
            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            # pack `member` into the next `bit_width` bits of `val`
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val

    @property
    def size_bits(self) -> int:
        """Total bit width: exponent + mantissa + sign bit (if signed)."""
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type's `nan_repr` is anything other than `NanRepr.NONE`"
        # `nan_repr` holds a NanRepr member, so compare against the member
        # itself; comparing against `.value` (an int) is never equal.
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        # Compare enum members (not `.value`), and require a floating point
        # type so that integer types keep reporting False.
        return (self.is_floating_point()
                and self.nan_repr == NanRepr.IEEE_754
                and not self._finite_values_only)

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
......@@ -17,14 +311,13 @@ class scalar_types:
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True,
NanRepr.EXTD_RANGE_MAX_MIN.value)
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
......
......@@ -4,7 +4,7 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
......@@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
......@@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)
@dataclass
class SequenceGroupBase:
    """Bookkeeping for one request that has been split into several
    sub-requests, tracked by their request ids."""

    group_id: str  # the original request id before splitting

    assembled_seq_group: Optional[SequenceGroup] = None

    # seq id to a unique index inside this group
    seq_id_to_index: Dict[str, int] = field(default_factory=dict)

    # seq ids to be finished
    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)

    # seq id to finished sequences
    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)

    streaming: bool = False

    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """When we are ready to add a request with request_id and params
        into the engine, we can split the request into multiple requests.

        Subclasses must override this; the base version always raises.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Move `seq` from the pending map to the finished map.

        Raises KeyError if `seq.request_id` is not currently pending.
        """
        finished_id = seq.request_id
        self.to_be_finished.pop(finished_id)
        self.finished_reqs[finished_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Assemble the sequence group, for producing the final
        output, or adding request in the engine again.

        Subclasses must override this; the base version always raises.
        """
        raise NotImplementedError
class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Splits a request with ``n`` samples into ``n`` single-sample
    sub-requests and reassembles their sequences into one group."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Register ``params.n`` sub-requests with ``engine`` and record the
        group in ``engine.seq_id_to_seq_group`` under each sub-request id."""
        # Each sub-request runs with a private copy of the params and n == 1;
        # the caller's `params` object is left untouched.
        single_params = copy.deepcopy(params)
        single_params.n = 1

        group = ParallelSampleSequenceGroup(request_id)
        member_seqs = []
        for index in range(params.n):
            child_id = f"{request_id}_parallel_sample_{index}"
            group.seq_id_to_index[child_id] = index
            child_group = engine.add_request(
                child_id,
                params=single_params,
                **kwargs,
            )  # type: ignore
            assert child_group is not None
            engine.seq_id_to_seq_group[child_id] = group
            group.to_be_finished[child_id] = child_group
            member_seqs.append(child_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        # Shared metadata is taken from the last sub-request created above.
        template = child_group
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=member_seqs,
            arrival_time=template.arrival_time,
            sampling_params=params,
            lora_request=template.lora_request,
            embeddings=template.embeddings,
            pooling_params=template.pooling_params,
            encoder_seq=template.encoder_seq,
            trace_headers=template.trace_headers,
            prompt_adapter_request=template.prompt_adapter_request,
            priority=template.priority,
        )

        group.streaming = (
            single_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when output should be produced for
        ``seq_group``, otherwise ``None``."""
        # in the streaming mode, we will return the assembled sequence
        # for the first sequence, and then return None for the rest of
        # sequences
        if self.streaming:
            is_first = self.seq_id_to_index[seq_group.request_id] == 0
            return self.assembled_seq_group if is_first else None

        # in the non-streaming mode, we will return the assembled sequence
        # once after all sequences finish, and then return None for the
        # rest of the time
        if self.to_be_finished:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)

        if self.output_produced:
            return None
        self.output_produced = True

        if params._real_n is not None:
            # Get the top-n sequences.
            n = params._real_n or params.n
            ranked = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True,
            )
            self.assembled_seq_group.seqs = ranked[:n]
        return self.assembled_seq_group
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment