Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
...@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
...@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -258,8 +260,11 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -258,8 +260,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states, logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable) Union, overload, runtime_checkable)
from typing_extensions import TypeGuard from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -10,12 +10,12 @@ logger = init_logger(__name__) ...@@ -10,12 +10,12 @@ logger = init_logger(__name__)
@runtime_checkable @runtime_checkable
class SupportsVision(Protocol): class SupportsMultiModal(Protocol):
"""The interface required for all vision language models (VLMs).""" """The interface required for all multi-modal models."""
supports_vision: ClassVar[Literal[True]] = True supports_multimodal: ClassVar[Literal[True]] = True
""" """
A flag that indicates this model supports vision inputs. A flag that indicates this model supports multi-modal inputs.
Note: Note:
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
...@@ -29,30 +29,31 @@ class SupportsVision(Protocol): ...@@ -29,30 +29,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks # We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead # so we need to treat the class as an instance and use isinstance instead
@runtime_checkable @runtime_checkable
class _SupportsVisionType(Protocol): class _SupportsMultiModalType(Protocol):
supports_vision: Literal[True] supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None: def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
... ...
@overload @overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]: def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
... ...
@overload @overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]: def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
... ...
def supports_vision( def supports_multimodal(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]: ) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsVisionType) return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsVision) return isinstance(model, SupportsMultiModal)
@runtime_checkable @runtime_checkable
...@@ -94,18 +95,18 @@ class _SupportsLoRAType(Protocol): ...@@ -94,18 +95,18 @@ class _SupportsLoRAType(Protocol):
@overload @overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]: def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
... ...
@overload @overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]: def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
... ...
def supports_lora( def supports_lora(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model) result = _supports_lora(model)
if not result: if not result:
...@@ -137,7 +138,7 @@ def supports_lora( ...@@ -137,7 +138,7 @@ def supports_lora(
def _supports_lora( def _supports_lora(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsLoRAType) return isinstance(model, _SupportsLoRAType)
...@@ -172,18 +173,18 @@ class _HasInnerStateType(Protocol): ...@@ -172,18 +173,18 @@ class _HasInnerStateType(Protocol):
@overload @overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]: def has_inner_state(model: object) -> TypeIs[HasInnerState]:
... ...
@overload @overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
... ...
def has_inner_state( def has_inner_state(
model: Union[Type[object], object] model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: ) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _HasInnerStateType) return isinstance(model, _HasInnerStateType)
......
...@@ -87,6 +87,7 @@ class InternLM2Attention(nn.Module): ...@@ -87,6 +87,7 @@ class InternLM2Attention(nn.Module):
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
...@@ -120,6 +121,14 @@ class InternLM2Attention(nn.Module): ...@@ -120,6 +121,14 @@ class InternLM2Attention(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
def split_qkv(self, qkv: torch.Tensor):
qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
q = q.reshape(-1, self.q_size)
k = k.reshape(-1, self.kv_size)
v = v.reshape(-1, self.kv_size)
return q, k, v
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -128,7 +137,7 @@ class InternLM2Attention(nn.Module): ...@@ -128,7 +137,7 @@ class InternLM2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states) qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.wo(attn_output) output, _ = self.wo(attn_output)
...@@ -264,6 +273,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -264,6 +273,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output = ParallelLMHead(config.vocab_size, self.output = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -279,8 +290,11 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -279,8 +290,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states, logits = self.logits_processor(self.output, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
...@@ -319,24 +333,6 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -319,24 +333,6 @@ class InternLM2ForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
if "wqkv" in name: weight_loader = getattr(param, "weight_loader",
config = self.config default_weight_loader)
kv_groups = (config.num_attention_heads // weight_loader(param, loaded_weight)
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,
loaded_weight.shape[-1])
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, 'q')
weight_loader(param, wk, 'k')
weight_loader(param, wv, 'v')
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import itertools import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -18,18 +19,18 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -18,18 +19,18 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .utils import merge_vision_embeddings from .utils import (filter_weights, init_vllm_registered_model,
merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -38,9 +39,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>' ...@@ -38,9 +39,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225) IMAGENET_STD = (0.229, 0.224, 0.225)
MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
class InternVLImagePixelInputs(TypedDict): class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
...@@ -53,6 +51,19 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -53,6 +51,19 @@ class InternVLImagePixelInputs(TypedDict):
""" """
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B # copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size): def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
...@@ -84,11 +95,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, ...@@ -84,11 +95,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return best_ratio return best_ratio
def calculate_num_blocks(orig_width: int, def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
orig_height: int, max_num: int,
min_num=1, image_size: int) -> Tuple[int, int, int]:
max_num=6,
image_size=448):
aspect_ratio = orig_width / orig_height aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio # calculate the existing image aspect ratio
...@@ -110,11 +119,9 @@ def calculate_num_blocks(orig_width: int, ...@@ -110,11 +119,9 @@ def calculate_num_blocks(orig_width: int,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image, def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
min_num=1, image_size: int,
max_num=6, use_thumbnail: int) -> List[Image.Image]:
image_size=448,
use_thumbnail=False):
orig_width, orig_height = image.size orig_width, orig_height = image.size
blocks, target_width, target_height = calculate_num_blocks( blocks, target_width, target_height = calculate_num_blocks(
...@@ -138,12 +145,14 @@ def dynamic_preprocess(image, ...@@ -138,12 +145,14 @@ def dynamic_preprocess(image,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6): def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
max_num: int, use_thumbnail: bool) -> torch.Tensor:
transform = build_transform(input_size=input_size) transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, images = dynamic_preprocess(image,
min_num=min_num,
max_num=max_num,
image_size=input_size, image_size=input_size,
use_thumbnail=True, use_thumbnail=use_thumbnail)
max_num=max_num)
pixel_values = [transform(image) for image in images] pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values) pixel_values = torch.stack(pixel_values)
return pixel_values return pixel_values
...@@ -157,14 +166,20 @@ def get_internvl_num_patches(image_size: int, patch_size: int, ...@@ -157,14 +166,20 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def get_max_internvl_image_tokens(ctx: InputContext): def get_max_internvl_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
use_thumbnail = hf_config.use_thumbnail
max_dynamic_patch = hf_config.max_dynamic_patch
if use_thumbnail:
max_dynamic_patch += 1
downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size image_size = vision_config.image_size
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size, num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio) downsample_ratio)
return num_patches * 7 return num_patches * max_dynamic_patch
def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
...@@ -173,24 +188,32 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -173,24 +188,32 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs return llm_inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
image_size = vision_config.image_size
patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
image_data = multi_modal_data["image"] image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image): if isinstance(image_data, Image.Image):
width, height = image_data.size width, height = image_data.size
num_blocks, _, _ = calculate_num_blocks(width, height) min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size)
# add thumbnail image if num_blocks > 1
if hf_config.use_thumbnail and num_blocks > 1:
num_blocks += 1
image_feature_size = num_blocks * num_patches
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") image_feature_size = image_data.shape[0]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
image_size = vision_config.image_size
patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True) trust_remote_code=True)
...@@ -198,8 +221,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -198,8 +221,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None: if prompt is None:
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
image_prompt = IMG_START + IMG_CONTEXT * (num_blocks + image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
1) * num_patches + IMG_END
new_prompt = prompt.replace('<image>', image_prompt, 1) new_prompt = prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt) new_prompt_token_ids = tokenizer.encode(new_prompt)
...@@ -209,8 +231,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -209,8 +231,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def input_mapper_for_internvl(ctx: InputContext, data: object): def input_mapper_for_internvl(ctx: InputContext, data: object):
hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
image_size = hf_config.vision_config.image_size
if isinstance(data, Image.Image): if isinstance(data, Image.Image):
data = image_to_pixel_values(data) data = image_to_pixel_values(data,
image_size,
min_num,
max_num,
use_thumbnail=use_thumbnail)
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True) trust_remote_code=True)
...@@ -224,11 +257,13 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): ...@@ -224,11 +257,13 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
}) })
def dummy_data_for_internvl(ctx: InputContext, seq_len: int): def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_internvl_image_tokens(ctx) image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True) trust_remote_code=True)
...@@ -236,14 +271,23 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): ...@@ -236,14 +271,23 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=tokenizer.encode(IMG_CONTEXT, image_token_id=tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False)[0], add_special_tokens=False)[0],
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
image_size = vision_config.image_size
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
max_image_width = max_num * image_size
max_image_height = min_num * image_size
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
vision_config, vision_config,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, num_images,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_width_override=max_image_width,
image_height_override=max_image_height,
) )
return seq_data, mm_data return seq_data, mm_data
...@@ -253,7 +297,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): ...@@ -253,7 +297,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision): class InternVLChatModel(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -283,10 +327,8 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -283,10 +327,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
self.vision_model = InternVisionModel( self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers) config.vision_config, num_hidden_layers_override=num_hidden_layers)
llm_class = ModelRegistry.load_model_cls( self.language_model = init_vllm_registered_model(
config.text_config.architectures[0]) config.text_config, cache_config, quant_config)
self.language_model = llm_class(config.text_config, cache_config,
quant_config)
vit_hidden_size = config.vision_config.hidden_size vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size llm_hidden_size = config.text_config.hidden_size
...@@ -356,23 +398,49 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -356,23 +398,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
return data return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImagePixelInputs]: self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None) image_token_id = kwargs.pop("image_token_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
self.img_context_token_id = image_token_id[0] self.img_context_token_id = image_token_id[0]
if not isinstance(pixel_values, (torch.Tensor, list)): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, (torch.Tensor, list)):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])
return InternVLImagePixelInputs( return image_embeds
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
def forward( def forward(
self, self,
...@@ -387,10 +455,10 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -387,10 +455,10 @@ class InternVLChatModel(nn.Module, SupportsVision):
if image_input is not None: if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
vit_embeds = self.extract_feature(image_input["data"]) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_multimodal_embeddings(
vit_embeds, input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id) self.img_context_token_id)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
...@@ -403,8 +471,11 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -403,8 +471,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
...@@ -415,24 +486,16 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -415,24 +486,16 @@ class InternVLChatModel(nn.Module, SupportsVision):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str):
for name, loaded_weight in weights:
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components # prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
# load vision encoder # load vision encoder
vit_weights = self._filter_weights(vit_weights, "vision_model") vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights) self.vision_model.load_weights(vit_weights)
# load mlp projector # load mlp projector
mlp_weights = self._filter_weights(mlp_weights, "mlp1") mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters()) mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights: for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name] param = mlp_params_dict[name]
...@@ -441,5 +504,5 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -441,5 +504,5 @@ class InternVLChatModel(nn.Module, SupportsVision):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone # load llm backbone
llm_weights = self._filter_weights(llm_weights, "language_model") llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights) self.language_model.load_weights(llm_weights)
...@@ -20,14 +20,14 @@ ...@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -37,12 +37,14 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -37,12 +37,14 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers
class SwiGLUActivation(nn.Module): class SwiGLUActivation(nn.Module):
...@@ -216,6 +218,7 @@ class JAISModel(nn.Module): ...@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -231,10 +234,15 @@ class JAISModel(nn.Module): ...@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self.embeddings_scale = config.embeddings_scale self.embeddings_scale = config.embeddings_scale
else: else:
self.embeddings_scale = config.mup_embeddings_scale self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, cache_config, quant_config) self.start_layer, self.end_layer, self.h = make_layers(
for _ in range(config.num_hidden_layers) config.num_hidden_layers,
]) lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -243,19 +251,29 @@ class JAISModel(nn.Module): ...@@ -243,19 +251,29 @@ class JAISModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds = self.wte(input_ids) ) -> Union[IntermediateTensors, torch.Tensor]:
if self.wpe is not None: if get_pp_group().is_first_rank:
position_embeds = self.wpe(position_ids) inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds + position_embeds if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)
else: else:
hidden_states = inputs_embeds assert intermediate_tensors is not None
hidden_states *= torch.tensor(float(self.embeddings_scale), hidden_states = intermediate_tensors["hidden_states"]
dtype=hidden_states.dtype)
for i in range(len(self.h)): for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -273,7 +291,11 @@ class JAISLMHeadModel(nn.Module): ...@@ -273,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config) self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head = self.transformer.wte if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
if hasattr(config, "width_scale"): if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale self.output_logits_scale = config.width_scale
else: else:
...@@ -290,17 +312,30 @@ class JAISLMHeadModel(nn.Module): ...@@ -290,17 +312,30 @@ class JAISLMHeadModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
...@@ -324,6 +359,10 @@ class JAISLMHeadModel(nn.Module): ...@@ -324,6 +359,10 @@ class JAISLMHeadModel(nn.Module):
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights. # Because of this, we need to transpose the weights.
......
...@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention ...@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module): ...@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
return hidden_states return hidden_states
class JambaMLP(nn.Module):
def __init__(
self,
config: JambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class JambaMoE(nn.Module): class JambaMoE(nn.Module):
def __init__(self, def __init__(self,
...@@ -327,6 +295,21 @@ class JambaMoE(nn.Module): ...@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
return hidden_states.view(orig_shape) return hidden_states.view(orig_shape)
class JambaMLP(JambaMoE):
def __init__(self,
config: JambaConfig,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(config,
num_experts=1,
top_k=1,
params_dtype=params_dtype,
tp_size=tp_size,
quant_config=quant_config)
class JambaMambaDecoderLayer(nn.Module): class JambaMambaDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
...@@ -609,12 +592,8 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -609,12 +592,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
) )
# Current step used indices
self.current_indices: List[int] = []
# Used to track and store by the Mamba cache between steps. # Used to track and store by the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Used as an input_buffer for the CUDA graph runs.
self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Maps between the request id and a dict that maps between the seq_id # Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache # and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
...@@ -644,95 +623,148 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -644,95 +623,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
batch_size = input_ids.shape[0] batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata: if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids) batch_size = len(request_ids_to_seq_ids)
( mamba_cache = self._prepare_current_run_mamba_cache(
current_seqlen_agnostic_cache, request_ids_to_seq_ids, batch_size, finished_requests_ids)
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size,
finished_requests_ids)
else: else:
# CUDA graph capturing runs # CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = ( mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
kwargs["seqlen_agnostic_capture_inputs"],
[],
)
self.current_indices = indices
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, attn_metadata, mamba_cache[0],
current_seqlen_agnostic_cache[0], mamba_cache[1])
current_seqlen_agnostic_cache[1])
if "seqlen_agnostic_capture_inputs" not in kwargs:
self._copy_mamba_cache_by_indices(self.current_indices,
current_seqlen_agnostic_cache)
return hidden_states return hidden_states
def _copy_mamba_cache_by_indices( def _swap_mamba_cache(self, from_index: int, to_index: int):
self, indices: List[int], assert len(self.mamba_cache) > 0
current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): for cache_t in self.mamba_cache:
for i, offset in enumerate(indices): cache_t[:, [to_index,from_index]] = \
self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) cache_t[:, [from_index,to_index]]
def _copy_mamba_cache(self, index_to: int, index_from: int, def _copy_mamba_cache(self, from_index: int, to_index: int):
from_buffer: Tuple[torch.Tensor, torch.Tensor]):
assert len(self.mamba_cache) > 0 assert len(self.mamba_cache) > 0
for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): for cache_t in self.mamba_cache:
cache_t[:, index_to].copy_(from_buffer_t[:, index_from], cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True) non_blocking=True)
def _assign_seq_id_to_mamba_cache(self, cur_rid: str, def _move_out_if_already_occupied(self, index: int,
seqs_id: List[int]) -> List[int]: all_occupied_indices: List[int]):
indices_for_current_run = [] if index in all_occupied_indices:
for seq_id in seqs_id: first_free_index = self._first_free_index_in_mamba_cache()
if cur_rid not in self.mamba_cache_indices_mapping: # In case occupied, move the occupied to a new empty block
self.mamba_cache_indices_mapping[cur_rid] = {} self._move_cache_index_and_mappings(from_index=index,
first_free_index = self._first_free_index_in_mamba_cache() to_index=first_free_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
index_for_current_run = first_free_index seq_id: int,
## case of decoding n>1, copy prefill cache to decoding indices destination_index: int):
elif seq_id not in (seq_ids2indices := """
self.mamba_cache_indices_mapping[cur_rid]): Assign (req_id,seq_id) pair to a `destination_index` index, if
first_free_index = self._first_free_index_in_mamba_cache() already occupied, move the occupying index to a free index.
index_exist = list(seq_ids2indices.values())[0] """
self._copy_mamba_cache(index_from=index_exist, all_occupied_indices = self._get_all_occupied_indices()
index_to=first_free_index, if cur_rid not in self.mamba_cache_indices_mapping:
from_buffer=self.mamba_cache) self._move_out_if_already_occupied(
self.mamba_cache_indices_mapping[cur_rid][ index=destination_index,
seq_id] = first_free_index all_occupied_indices=all_occupied_indices)
index_for_current_run = first_free_index self.mamba_cache_indices_mapping[cur_rid] = {
else: seq_id: destination_index
index_for_current_run = self.mamba_cache_indices_mapping[ }
cur_rid][seq_id] elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
indices_for_current_run.append(index_for_current_run) # parallel sampling , where n > 1, assume prefill have
return indices_for_current_run # already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache( def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str] batch_size: int, finished_requests_ids: List[str]):
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: running_indices = []
indices_for_current_run = [] request_ids_to_seq_ids_flatten = [
for request_id, seqs_id in request_ids_to_seq_ids.items(): (req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids: if request_id in finished_requests_ids:
# Do not allocate cache for requests that run # Do not allocate cache index for requests that run
# and finish right after # and finish right after
continue continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache( self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seqs_id) request_id, seq_id, dest_index)
## Pad the batch in case of running batch that was not captured via CG running_indices.append(dest_index)
padded_indices = indices_for_current_run.copy()
pad_index = self._first_free_index_in_mamba_cache()
for _ in range(batch_size - len(indices_for_current_run)): self._clean_up_first_bs_blocks(batch_size, running_indices)
padded_indices.append(pad_index) conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
conv_state = self.mamba_cache[0][:, padded_indices] return (conv_state, temporal_state)
temporal_state = self.mamba_cache[1][:, padded_indices]
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
return (conv_state, temporal_state), indices_for_current_run def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = set([range(batch_size)])
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
""" """
...@@ -747,28 +779,9 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -747,28 +779,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
self._release_mamba_cache(finished_requests_ids) self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0] cg_batch_size = input_buffers['input_ids'].shape[0]
( self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
current_mamba_cache, cg_batch_size,
indices, finished_requests_ids)
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
self.current_indices = indices
for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
current_mamba_cache):
input_buffer.copy_(current_cache_buffer, non_blocking=True)
def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self._copy_mamba_cache_by_indices(
self.current_indices,
input_buffers["seqlen_agnostic_capture_inputs"])
def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
""" """
...@@ -776,26 +789,25 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -776,26 +789,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
The buffer is used to maintain the Mamba Cache during the CUDA graph The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs. replay runs.
""" """
return tuple(buffer[:, :batch_size] return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
for buffer in self.mamba_gc_cache_buffer)
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids: for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping: if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id) self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(self) -> int: def _first_free_index_in_mamba_cache(
if self.mamba_cache: self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1] max_possible_batch_size = self.mamba_cache[0].shape[1]
occupied = [ indices_range = list(range(max_possible_batch_size))
id for seq_ids in self.mamba_cache_indices_mapping.values() all_occupied_indices = self._get_all_occupied_indices()
for id in seq_ids.values() for i in indices_range:
] if i not in all_occupied_indices:
first_free_index = [ return i
i not in occupied for i in range(max_possible_batch_size) raise Exception("Couldn't find a free spot in the mamba cache! This"
].index(True) "should never happen")
return first_free_index
return 0
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self self
...@@ -819,23 +831,24 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -819,23 +831,24 @@ class JambaForCausalLM(nn.Module, HasInnerState):
[layer_type == "mamba" for layer_type in layers_type]) [layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size( max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10 max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None assert conv_state_shape is not None and temporal_state_shape is not None
for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
buffer = (torch.empty(size=(mamba_layers, max_batch_size) + conv_state_shape,
conv_state_shape, dtype=dtype,
dtype=dtype, device="cuda"),
device="cuda"), torch.empty(size=(mamba_layers, max_batch_size) +
torch.empty(size=(mamba_layers, max_batch_size) + temporal_state_shape,
temporal_state_shape, dtype=dtype,
dtype=dtype, device="cuda"))
device="cuda"))
setattr(self, buffername, buffer) def compute_logits(
self,
def compute_logits(self, hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
...@@ -854,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -854,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
] ]
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
...@@ -877,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -877,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if ".self_attn." in name: if ".self_attn." in name:
name = name.replace(".self_attn", "") name = name.replace(".self_attn", "")
if "feed_forward" in name and not _is_moe_layer(name):
## map MLP layers to expert with ID=0
name = name.replace("feed_forward", "feed_forward.experts.0")
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
...@@ -891,10 +906,15 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -891,10 +906,15 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
for mapping in expert_params_mapping: for (
param_name, weight_name, expert_id, shard_id = mapping param_name,
weight_name,
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
...@@ -913,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -913,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def _is_moe_layer(name: str):
return any(
[experts_name in name for experts_name in [
"experts",
"router",
]])
...@@ -145,6 +145,7 @@ class LlamaAttention(nn.Module): ...@@ -145,6 +145,7 @@ class LlamaAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim, input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size, output_size=hidden_size,
...@@ -153,12 +154,17 @@ class LlamaAttention(nn.Module): ...@@ -153,12 +154,17 @@ class LlamaAttention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -291,6 +297,7 @@ class LlamaModel(nn.Module): ...@@ -291,6 +297,7 @@ class LlamaModel(nn.Module):
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
quant_config=quant_config,
) )
else: else:
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
...@@ -444,8 +451,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -444,8 +451,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return model_output return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens, input_processor_for_clip) dummy_seq_data_for_clip, get_max_clip_image_tokens,
from .interfaces import SupportsVision input_processor_for_clip)
from .utils import merge_vision_embeddings from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_multimodal_embeddings)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head", class LlavaImagePixelInputs(TypedDict):
"language_model.model": "language_model", type: Literal["pixel_values"]
} data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
# TODO(xwjiang): Run benchmark and decide if TP. # TODO(xwjiang): Run benchmark and decide if TP.
...@@ -53,38 +67,56 @@ class LlavaMultiModalProjector(nn.Module): ...@@ -53,38 +67,56 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states return hidden_states
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs = LlavaImagePixelInputs
def get_max_llava_image_tokens(ctx: InputContext): def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return get_max_clip_image_tokens(vision_config) num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
msg = f"Unsupported vision config: {type(vision_config)}" num_image_tokens = get_max_siglip_image_tokens(vision_config)
raise NotImplementedError(msg) else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
return num_image_tokens - 1
elif strategy == "full":
return num_image_tokens
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int): def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_clip(vision_config) mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
...@@ -100,12 +132,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -100,12 +132,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, llm_inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
) )
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
...@@ -116,7 +185,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -116,7 +185,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
...@@ -128,36 +197,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -128,36 +197,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel( self.vision_tower = _init_vision_tower(config)
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config self.language_model = init_vllm_registered_model(
self.language_model = LlamaModel(config.text_config, cache_config, config.text_config, cache_config, quant_config)
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
...@@ -175,18 +223,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -175,18 +223,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]: self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if not isinstance(pixel_values, torch.Tensor): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, torch.Tensor):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor: strategy: str) -> torch.Tensor:
...@@ -198,8 +258,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -198,8 +258,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected select feature strategy: {strategy}") raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, def _image_pixels_to_features(
pixel_values: torch.Tensor) -> torch.Tensor: self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since # NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower # this is already done inside the vision tower
...@@ -220,6 +283,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -220,6 +283,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self, def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor: image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
...@@ -246,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -246,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in: additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
...@@ -264,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -264,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values: The pixels in each input image. pixel_values: The pixels in each input image.
See also: See also:
:class:`LlavaImageInputs` :class:`LlavaImageInputs`
""" """
...@@ -272,9 +339,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -272,9 +339,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
...@@ -282,68 +350,47 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -282,68 +350,47 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, None,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
logits = self.logits_processor(self.lm_head, hidden_states, hidden_states: torch.Tensor,
sampling_metadata) sampling_metadata: SamplingMetadata,
return logits ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now. # prepare weight iterators for components
stacked_params_mapping = [ vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), # load vision encoder
("qkv_proj", "k_proj", "k"), vit_weights = filter_weights(vit_weights, "vision_tower")
("qkv_proj", "v_proj", "v"), self.vision_tower.load_weights(vit_weights)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), # load mlp projector
] mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
params_dict = dict(self.named_parameters()) mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in mlp_weights:
if "rotary_emb.inv_freq" in name: param = mlp_params_dict[name]
continue weight_loader = getattr(param, "weight_loader",
# post_layernorm is not needed in CLIPVisionModel default_weight_loader)
if "vision_model.post_layernorm" in name: weight_loader(param, loaded_weight)
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): # load llm backbone
if key_to_modify in name: llm_weights = filter_weights(llm_weights, "language_model")
name = name.replace(key_to_modify, new_key) self.language_model.load_weights(llm_weights)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment