Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
from torch import nn
......@@ -9,20 +10,19 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = {
}
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class PaliGemmaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs]
def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
......@@ -38,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
)
mm_data = dummy_image_for_siglip(vision_config)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
......@@ -107,20 +129,11 @@ class PaliGemmaMultiModalProjector(nn.Module):
return hidden_states
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
PaliGemmaImageInputs = PaliGemmaImagePixelInputs
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: PaliGemmaConfig,
......@@ -163,19 +176,31 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
......@@ -187,27 +212,21 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return image_features
def _process_image_pixels(
def _process_image_input(
self,
inputs: PaliGemmaImagePixelInputs,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if image_input["type"] == "image_embeds":
return image_input["data"]
return self._image_pixels_to_features(
assert self.vision_tower is not None
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
def _process_image_input(
self,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input, )
return self.multi_modal_projector(image_features)
def forward(self,
......@@ -228,7 +247,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
......@@ -246,8 +265,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
# Copied from vllm/model_executor/models/gemma.py
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.language_model.embed_tokens,
hidden_states, sampling_metadata)
return logits
......
......@@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
super().__init__()
self.config = config
# lm_head use bias, cannot share word embeddings
assert not config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config
......@@ -286,8 +288,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
......
......@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -399,8 +401,11 @@ class Phi3SmallForCausalLM(nn.Module):
def get_decoder(self):
return self.model
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.dummy_token_indices is not None and logits is not None:
......@@ -446,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)
......@@ -15,7 +15,8 @@
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)
import numpy as np
import torch
......@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -37,13 +37,13 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768)
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
class Phi3VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None:
......@@ -189,7 +219,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
global_image_features_hd_newline = self.add_image_newline(
global_image_features_hd)
all_image_embeddings = []
batch_image_features_proj = []
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
......@@ -207,19 +237,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
sub_image_features_hd)
# [sub features, separator, global features]
all_image_embeddings.append(
torch.cat([
image_embeddings = torch.cat([
sub_image_features_hd_newline.squeeze(
0), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]))
])
img_proj = self.img_projection(
image_embeddings.to(target_device, target_dtype))
batch_image_features_proj.append(img_proj)
image_features_proj = self.img_projection(
torch.stack(all_image_embeddings).to(target_device, target_dtype)
) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
return image_features_proj
return batch_image_features_proj
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
......@@ -259,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
......@@ -314,12 +324,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: PretrainedConfig,
hf_config: Dict[str, Any],
*,
input_height: int,
input_width: int,
) -> int:
num_crops = getattr(hf_config, "num_crops", 16)
num_crops = hf_config.get("num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
......@@ -331,24 +341,28 @@ def get_phi3v_image_feature_size(
def get_max_phi3v_image_tokens(ctx: InputContext):
return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
ctx.get_hf_image_processor_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_phi3v_image_tokens(ctx)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
num_images,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
......@@ -381,7 +395,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_image_processor_config()
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
......@@ -392,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width=w,
input_height=h)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
......@@ -443,7 +457,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
......@@ -463,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -496,13 +512,18 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
self, **kwargs: object) -> Optional[Phi3VImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
return None
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
......@@ -516,6 +537,31 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Phi3VImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Phi3VImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(image_input["data"],
image_input["image_sizes"])
return image_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
......@@ -526,11 +572,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self.vision_embed_tokens(
image_input["data"], image_input["image_sizes"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
......@@ -545,8 +590,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment