Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
......@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
config,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
......@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -258,8 +260,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits
......
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
from typing_extensions import TypeGuard
from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger
......@@ -10,12 +10,12 @@ logger = init_logger(__name__)
@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""
class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models."""
supports_vision: ClassVar[Literal[True]] = True
supports_multimodal: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports vision inputs.
A flag that indicates this model supports multi-modal inputs.
Note:
There is no need to redefine this flag if this class is in the
......@@ -29,30 +29,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
...
@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
...
@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
...
def supports_vision(
def supports_multimodal(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)
return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsVision)
return isinstance(model, SupportsMultiModal)
@runtime_checkable
......@@ -94,18 +95,18 @@ class _SupportsLoRAType(Protocol):
@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
...
@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
...
def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model)
if not result:
......@@ -137,7 +138,7 @@ def supports_lora(
def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
......@@ -172,18 +173,18 @@ class _HasInnerStateType(Protocol):
@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
...
@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
...
def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)
......
......@@ -87,6 +87,7 @@ class InternLM2Attention(nn.Module):
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
......@@ -120,6 +121,14 @@ class InternLM2Attention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
def split_qkv(self, qkv: torch.Tensor):
qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
q = q.reshape(-1, self.q_size)
k = k.reshape(-1, self.kv_size)
v = v.reshape(-1, self.kv_size)
return q, k, v
def forward(
self,
positions: torch.Tensor,
......@@ -128,7 +137,7 @@ class InternLM2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.wo(attn_output)
......@@ -264,6 +273,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -279,8 +290,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states,
sampling_metadata)
return logits
......@@ -319,24 +333,6 @@ class InternLM2ForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,
loaded_weight.shape[-1])
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, 'q')
weight_loader(param, wk, 'k')
weight_loader(param, wv, 'v')
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
......@@ -18,18 +19,18 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_tokenizer
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import (filter_weights, init_vllm_registered_model,
merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
......@@ -38,9 +39,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
......@@ -53,6 +51,19 @@ class InternVLImagePixelInputs(TypedDict):
"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
......@@ -84,11 +95,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return best_ratio
def calculate_num_blocks(orig_width: int,
orig_height: int,
min_num=1,
max_num=6,
image_size=448):
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
max_num: int,
image_size: int) -> Tuple[int, int, int]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
......@@ -110,11 +119,9 @@ def calculate_num_blocks(orig_width: int,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image,
min_num=1,
max_num=6,
image_size=448,
use_thumbnail=False):
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int,
use_thumbnail: int) -> List[Image.Image]:
orig_width, orig_height = image.size
blocks, target_width, target_height = calculate_num_blocks(
......@@ -138,12 +145,14 @@ def dynamic_preprocess(image,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
max_num: int, use_thumbnail: bool) -> torch.Tensor:
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image,
min_num=min_num,
max_num=max_num,
image_size=input_size,
use_thumbnail=True,
max_num=max_num)
use_thumbnail=use_thumbnail)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
......@@ -157,14 +166,20 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def get_max_internvl_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
use_thumbnail = hf_config.use_thumbnail
max_dynamic_patch = hf_config.max_dynamic_patch
if use_thumbnail:
max_dynamic_patch += 1
downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size
patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
return num_patches * 7
return num_patches * max_dynamic_patch
def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
......@@ -173,24 +188,32 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
image_size = vision_config.image_size
patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = calculate_num_blocks(width, height)
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size)
# add thumbnail image if num_blocks > 1
if hf_config.use_thumbnail and num_blocks > 1:
num_blocks += 1
image_feature_size = num_blocks * num_patches
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
image_size = vision_config.image_size
patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
......@@ -198,8 +221,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
1) * num_patches + IMG_END
image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
new_prompt = prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
......@@ -209,8 +231,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def input_mapper_for_internvl(ctx: InputContext, data: object):
hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
image_size = hf_config.vision_config.image_size
if isinstance(data, Image.Image):
data = image_to_pixel_values(data)
data = image_to_pixel_values(data,
image_size,
min_num,
max_num,
use_thumbnail=use_thumbnail)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
......@@ -224,11 +257,13 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
})
def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
......@@ -236,14 +271,23 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
image_size = vision_config.image_size
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
max_image_width = max_num * image_size
max_image_height = min_num * image_size
mm_data = dummy_image_for_clip(
vision_config,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)
return seq_data, mm_data
......@@ -253,7 +297,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision):
class InternVLChatModel(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
......@@ -283,10 +327,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
llm_class = ModelRegistry.load_model_cls(
config.text_config.architectures[0])
self.language_model = llm_class(config.text_config, cache_config,
quant_config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
......@@ -356,23 +398,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
self.img_context_token_id = image_token_id[0]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return image_embeds
def forward(
self,
......@@ -387,10 +455,10 @@ class InternVLChatModel(nn.Module, SupportsVision):
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vit_embeds = self.extract_feature(image_input["data"])
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vit_embeds,
self.img_context_token_id)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
inputs_embeds = None
......@@ -403,8 +471,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
......@@ -415,24 +486,16 @@ class InternVLChatModel(nn.Module, SupportsVision):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str):
for name, loaded_weight in weights:
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
# load vision encoder
vit_weights = self._filter_weights(vit_weights, "vision_model")
vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)
# load mlp projector
mlp_weights = self._filter_weights(mlp_weights, "mlp1")
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
......@@ -441,5 +504,5 @@ class InternVLChatModel(nn.Module, SupportsVision):
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = self._filter_weights(llm_weights, "language_model")
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
......@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
......@@ -37,12 +37,14 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers
class SwiGLUActivation(nn.Module):
......@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self.embeddings_scale = config.embeddings_scale
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -243,19 +251,29 @@ class JAISModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)
else:
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)):
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -273,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head = self.transformer.wte
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:
......@@ -290,17 +312,30 @@ class JAISLMHeadModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample(
self,
logits: torch.Tensor,
......@@ -324,6 +359,10 @@ class JAISLMHeadModel(nn.Module):
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
......
......@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
return hidden_states
class JambaMLP(nn.Module):
def __init__(
self,
config: JambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class JambaMoE(nn.Module):
def __init__(self,
......@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
return hidden_states.view(orig_shape)
class JambaMLP(JambaMoE):
def __init__(self,
config: JambaConfig,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(config,
num_experts=1,
top_k=1,
params_dtype=params_dtype,
tp_size=tp_size,
quant_config=quant_config)
class JambaMambaDecoderLayer(nn.Module):
def __init__(self,
......@@ -609,12 +592,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Current step used indices
self.current_indices: List[int] = []
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Used as an input_buffer for the CUDA graph runs.
self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
......@@ -644,95 +623,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
(
current_seqlen_agnostic_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size,
finished_requests_ids)
mamba_cache = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, batch_size, finished_requests_ids)
else:
# CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = (
kwargs["seqlen_agnostic_capture_inputs"],
[],
)
self.current_indices = indices
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata,
current_seqlen_agnostic_cache[0],
current_seqlen_agnostic_cache[1])
if "seqlen_agnostic_capture_inputs" not in kwargs:
self._copy_mamba_cache_by_indices(self.current_indices,
current_seqlen_agnostic_cache)
attn_metadata, mamba_cache[0],
mamba_cache[1])
return hidden_states
def _copy_mamba_cache_by_indices(
self, indices: List[int],
current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
for i, offset in enumerate(indices):
self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
def _copy_mamba_cache(self, index_to: int, index_from: int,
from_buffer: Tuple[torch.Tensor, torch.Tensor]):
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
seqs_id: List[int]) -> List[int]:
indices_for_current_run = []
for seq_id in seqs_id:
if cur_rid not in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping[cur_rid] = {}
first_free_index = self._first_free_index_in_mamba_cache()
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index
index_for_current_run = first_free_index
## case of decoding n>1, copy prefill cache to decoding indices
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
first_free_index = self._first_free_index_in_mamba_cache()
index_exist = list(seq_ids2indices.values())[0]
self._copy_mamba_cache(index_from=index_exist,
index_to=first_free_index,
from_buffer=self.mamba_cache)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index
index_for_current_run = first_free_index
else:
index_for_current_run = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
indices_for_current_run.append(index_for_current_run)
return indices_for_current_run
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
finished_requests_ids: List[str]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
indices_for_current_run = []
for request_id, seqs_id in request_ids_to_seq_ids.items():
self, request_ids_to_seq_ids: Dict[str, list[int]],
batch_size: int, finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache for requests that run
# Do not allocate cache index for requests that run
# and finish right after
continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
request_id, seqs_id)
## Pad the batch in case of running batch that was not captured via CG
padded_indices = indices_for_current_run.copy()
pad_index = self._first_free_index_in_mamba_cache()
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
for _ in range(batch_size - len(indices_for_current_run)):
padded_indices.append(pad_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
conv_state = self.mamba_cache[0][:, padded_indices]
temporal_state = self.mamba_cache[1][:, padded_indices]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
return (conv_state, temporal_state), indices_for_current_run
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = set([range(batch_size)])
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
......@@ -747,28 +779,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
self.current_indices = indices
for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
current_mamba_cache):
input_buffer.copy_(current_cache_buffer, non_blocking=True)
def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self._copy_mamba_cache_by_indices(
self.current_indices,
input_buffers["seqlen_agnostic_capture_inputs"])
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
......@@ -776,26 +789,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size]
for buffer in self.mamba_gc_cache_buffer)
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(self) -> int:
if self.mamba_cache:
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
occupied = [
id for seq_ids in self.mamba_cache_indices_mapping.values()
for id in seq_ids.values()
]
first_free_index = [
i not in occupied for i in range(max_possible_batch_size)
].index(True)
return first_free_index
return 0
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")
def _get_mamba_cache_shape(
self
......@@ -819,23 +831,24 @@ class JambaForCausalLM(nn.Module, HasInnerState):
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10
max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None
for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
setattr(self, buffername, buffer)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......@@ -854,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
......@@ -877,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if ".self_attn." in name:
name = name.replace(".self_attn", "")
if "feed_forward" in name and not _is_moe_layer(name):
## map MLP layers to expert with ID=0
name = name.replace("feed_forward", "feed_forward.experts.0")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
......@@ -891,10 +906,15 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
for (
param_name,
weight_name,
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
......@@ -913,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def _is_moe_layer(name: str):
return any(
[experts_name in name for experts_name in [
"experts",
"router",
]])
......@@ -145,6 +145,7 @@ class LlamaAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
......@@ -153,12 +154,17 @@ class LlamaAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = Attention(self.num_heads,
self.head_dim,
......@@ -291,6 +297,7 @@ class LlamaModel(nn.Module):
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
......@@ -444,8 +451,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_max_clip_image_tokens, input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_multimodal_embeddings)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
# TODO(xwjiang): Run benchmark and decide if TP.
......@@ -53,38 +67,56 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs = LlavaImagePixelInputs
def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return get_max_clip_image_tokens(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
return num_image_tokens - 1
elif strategy == "full":
return num_image_tokens
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
......@@ -100,12 +132,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
......@@ -116,7 +185,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaConfig,
......@@ -128,36 +197,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = Sampler()
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
......@@ -175,18 +223,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
......@@ -198,8 +258,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
......@@ -220,6 +283,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
......@@ -246,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
......@@ -264,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`LlavaImageInputs`
"""
......@@ -272,9 +339,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
......@@ -282,68 +350,47 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment