Unverified commit 86222a3d, authored by Jee Jee Li, committed by GitHub
Browse files

[VLM] Merged multi-modal processor for GLM4V (#12449)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent fe743b79
...@@ -719,7 +719,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -719,7 +719,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `THUDM/glm-4v-9b` etc. * `THUDM/glm-4v-9b` etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `H2OVLChatModel` - * `H2OVLChatModel`
* H2OVL * H2OVL
* T + I<sup>E+</sup> * T + I<sup>E+</sup>
......
...@@ -106,7 +106,9 @@ def run_glm4v(question: str, modality: str): ...@@ -106,7 +106,9 @@ def run_glm4v(question: str, modality: str):
trust_remote_code=True, trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = question prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>"
stop_token_ids = [151329, 151336, 151338] stop_token_ids = [151329, 151336, 151338]
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
......
...@@ -147,6 +147,7 @@ def _test_processing_correctness( ...@@ -147,6 +147,7 @@ def _test_processing_correctness(
"facebook/chameleon-7b", "facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny", "deepseek-ai/deepseek-vl2-tiny",
"adept/fuyu-8b", "adept/fuyu-8b",
"THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceM4/Idefics3-8B-Llama3",
......
...@@ -4,20 +4,21 @@ ...@@ -4,20 +4,21 @@
# https://github.com/THUDM/CogAgent # https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights.""" """Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace from argparse import Namespace
from array import array from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, Union)
TypedDict)
import torch import torch
from PIL import Image
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -35,73 +36,55 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel ...@@ -35,73 +36,55 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
NestedTensors) from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, BaseProcessingInfo, BatchFeature,
SequenceData) BoundPromptReplacement,
MultiModalFieldConfig,
PlaceholderFeaturesInfo,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
IMAGE_TOKEN_ID = 151329
def calculate_image_placeholder(vision_config):
return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
def build_normalization_transform(image_size: int) -> transforms.Compose:
"""
Build a normalization transform which can be applied to one or
more input images from which we want to extract visual features.
def mm_input_mapper_for_glmv( Args:
ctx: InputContext, image_size: size of the image to be processed for visual embeddings.
data: ModalityData[object],
) -> Dict: Returns:
model_config = ctx.model_config Callable transform for normalizing and resizing one RGB image.
tokenizer = cached_get_tokenizer( """
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code) return transforms.Compose([
if tokenizer is None: transforms.Resize(
raise RuntimeError("No HuggingFace processor is available " (image_size, image_size),
"to process the image object") interpolation=InterpolationMode.BICUBIC,
try: ),
raw_batch_data = tokenizer.apply_chat_template( transforms.ToTensor(),
conversation=[{ transforms.Normalize(
"role": "user", (0.48145466, 0.4578275, 0.40821073),
"image": data (0.26862954, 0.26130258, 0.27577711),
}], ),
add_generation_prompt=True, ])
tokenize=True,
return_tensors="pt",
return_dict=True).data def calculate_image_placeholder(vision_config):
except Exception: return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
logger.error("Failed to process image (%s)", data)
raise
pixel_values = raw_batch_data['images']
return MultiModalKwargs({'pixel_values': pixel_values})
def merge_glm_vision_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
boi_token_id: int,
eoi_token_id: int,
) -> torch.Tensor:
boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
mask = torch.zeros_like(input_ids, dtype=torch.bool)
for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
assert boi_pos < eoi_pos
mask[boi_pos:eoi_pos + 1] = True
inputs_embeds[mask] = vision_embeddings.view(-1,
vision_embeddings.shape[-1])
return inputs_embeds
class GLMImagePixelInputs(TypedDict): class GLMImagePixelInputs(TypedDict):
...@@ -109,120 +92,177 @@ class GLMImagePixelInputs(TypedDict): ...@@ -109,120 +92,177 @@ class GLMImagePixelInputs(TypedDict):
"""Shape: `(batch_size, num_channels, height, width)`""" """Shape: `(batch_size, num_channels, height, width)`"""
def get_max_glmv_image_tokens(ctx: InputContext): class GLM4VProcessor:
hf_config = ctx.get_hf_config(ChatGLMConfig) """
This model doesn't define its own HF processor,
so we implement our own one here.
vision_config = getattr(hf_config, 'vision_config', None) """
if vision_config is None:
return 1
elif isinstance(vision_config, dict):
return calculate_image_placeholder(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}" def __init__(
raise NotImplementedError(msg) self,
config: ChatGLMConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
def dummy_data_for_glmv(ctx: InputContext, seq_len: int, if hasattr(self.config, "vision_config"):
mm_counts: Mapping[str, int]) -> DummyData: self.image_transform = build_normalization_transform(
hf_config = ctx.get_hf_config(ChatGLMConfig) config.vision_config["image_size"])
vision_config = getattr(hf_config, 'vision_config', None) else:
self.image_transform = None
if vision_config is None: def __call__(
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) self,
seq_data = SequenceData(token_ids) text: Optional[Union[TextInput, list[TextInput]]] = None,
return DummyData(seq_data, None) images: Optional[Union[ImageInput, list[ImageInput]]] = None,
elif isinstance(vision_config, dict): return_tensors: Optional[Union[str, TensorType]] = None,
image_size = vision_config["image_size"] ) -> BatchFeature:
image_placeholder_length = calculate_image_placeholder(vision_config) if text is None:
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] + text = []
[0] * image_placeholder_length + if not isinstance(text, list):
[hf_config.eoi_token_id]) text = [text]
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, if images is None:
[0] * (seq_len - image_placeholder_length - 2)) images = []
seq_data = SequenceData(token_ids) if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
if self.image_transform is None:
raise ValueError("This model does not support image inputs")
pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
mm_data = {
"image": Image.new("RGB", (image_size, image_size), color=0)
}
return DummyData(seq_data, mm_data) class GLM4VProcessingInfo(BaseProcessingInfo):
msg = f"Unsupported vision config: {type(vision_config)}" def __init__(self, ctx):
raise NotImplementedError(msg) super().__init__(ctx)
self._pre_calculate()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def find_all_positions(input_ids: List[int], target: int) -> List[int]: def get_mm_max_tokens_per_item(
return [index for index, value in enumerate(input_ids) if value == target] self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.image_token_num + 2}
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): def _pre_calculate(self):
multi_modal_data = inputs.get("multi_modal_data") hf_config = self.get_hf_config()
if multi_modal_data is None or "image" not in multi_modal_data: vision_config = hf_config.vision_config
return inputs self.image_token_num = calculate_image_placeholder(vision_config)
self.image_size = vision_config["image_size"]
hf_config = ctx.get_hf_config(ChatGLMConfig) def get_num_image_tokens(self) -> int:
vision_config = getattr(hf_config, 'vision_config', None) return self.image_token_num + 2
if vision_config is None: def get_image_size(self) -> ImageSize:
return inputs
elif isinstance(vision_config, dict):
image_placeholder_length = calculate_image_placeholder(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
input_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(
ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code)
try:
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": multi_modal_data["image"],
"content": inputs['prompt'],
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
).data
except Exception:
logger.error("Failed to process content (%s)", inputs['prompt'])
raise
input_ids = raw_batch_data['input_ids'][0].tolist()
boi_token_id = hf_config.boi_token_id return ImageSize(height=self.image_size, width=self.image_size)
eoi_token_id = hf_config.eoi_token_id
boi_positions = find_all_positions(input_ids, boi_token_id)
eoi_positions = find_all_positions(input_ids, eoi_token_id)
assert len(boi_positions) == len(eoi_positions) def get_hf_processor(self) -> GLM4VProcessor:
return GLM4VProcessor(
self.get_hf_config(),
self.get_tokenizer(),
)
new_input_ids = []
final_processed_position = 0
for boi_position, eoi_position in zip(boi_positions, eoi_positions): class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
assert boi_position < eoi_position
new_input_ids.extend(input_ids[final_processed_position:boi_position + def get_dummy_processor_inputs(
1]) self,
new_input_ids.extend([input_ids[boi_position + 1]] * seq_len: int,
image_placeholder_length) mm_counts: Mapping[str, int],
final_processed_position = eoi_position ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
return ProcessorInputs(
prompt_text=text,
mm_data=mm_data,
)
new_input_ids.extend(input_ids[final_processed_position:])
prompt = inputs.get("prompt") class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
if prompt is None:
prompt = tokenizer.decode(new_input_ids) def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
return token_inputs( def _get_prompt_replacements(
prompt_token_ids=new_input_ids, self,
prompt=prompt, mm_items: MultiModalDataItems,
multi_modal_data=multi_modal_data, hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
def get_replacement(item_idx: int):
image_tokens = self.info.image_token_num
return [IMAGE_TOKEN_ID] * image_tokens
return [
PromptReplacement(
modality="image",
target=[IMAGE_TOKEN_ID],
replacement=get_replacement,
),
]
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts,
) )
hf_config = self.info.get_hf_config()
boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id
placeholders = {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
tokens=[boi_token_id] + p.tokens + [eoi_token_id],
) for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
...@@ -572,12 +612,16 @@ class ChatGLMModel(nn.Module): ...@@ -572,12 +612,16 @@ class ChatGLMModel(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids) inputs_embeds = self.embedding(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_glm_vision_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
vision_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
boi_token_id=self.config.boi_token_id, placeholder_token_id=[
eoi_token_id=self.config.eoi_token_id) self.config.boi_token_id,
IMAGE_TOKEN_ID,
self.config.eoi_token_id,
],
)
return inputs_embeds return inputs_embeds
def forward( def forward(
...@@ -593,14 +637,12 @@ class ChatGLMModel(nn.Module): ...@@ -593,14 +637,12 @@ class ChatGLMModel(nn.Module):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
if intermediate_tensors is None and inputs_embeds is None: if intermediate_tensors is not None:
inputs_embeds = intermediate_tensors["hidden_states"]
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
input_ids = None
else:
inputs_embeds = intermediate_tensors["hidden_states"]
# Run encoder. # Run encoder.
hidden_states = self.encoder( hidden_states = self.encoder(
hidden_states=inputs_embeds, hidden_states=inputs_embeds,
...@@ -763,11 +805,21 @@ class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal): ...@@ -763,11 +805,21 @@ class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
connector="transformer.vision.linear_proj", connector="transformer.vision.linear_proj",
tower_model="transformer.vision.transformer") tower_model="transformer.vision.transformer")
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
return self.transformer.get_multimodal_embeddings(**kwargs)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids,
multimodal_embeddings)
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) @MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) info=GLM4VProcessingInfo,
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) dummy_inputs=GLM4VDummyInputsBuilder)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal): SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment