Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
...@@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class SwiGLUActivation(nn.Module): class SwiGLUActivation(nn.Module):
...@@ -244,6 +245,9 @@ class JAISModel(nn.Module): ...@@ -244,6 +245,9 @@ class JAISModel(nn.Module):
) )
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -279,7 +283,7 @@ class JAISModel(nn.Module): ...@@ -279,7 +283,7 @@ class JAISModel(nn.Module):
return hidden_states return hidden_states
class JAISLMHeadModel(nn.Module): class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module): ...@@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale) scale=self.output_logits_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module): ...@@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
......
# coding=utf-8 # coding=utf-8
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -25,31 +22,25 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -25,31 +22,25 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update) selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import HasInnerState composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size) _get_graph_batch_size)
from .interfaces import SupportsLoRA from .interfaces import HasInnerState, SupportsLoRA
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module): class JambaMambaMixer(nn.Module):
""" """
...@@ -62,10 +53,9 @@ class JambaMambaMixer(nn.Module): ...@@ -62,10 +53,9 @@ class JambaMambaMixer(nn.Module):
**selective** state spaces) **selective** state spaces)
""" """
def __init__(self, config: JambaConfig, layer_idx): def __init__(self, config: JambaConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.ssm_state_size = config.mamba_d_state self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv self.conv_kernel_size = config.mamba_d_conv
...@@ -101,16 +91,6 @@ class JambaMambaMixer(nn.Module): ...@@ -101,16 +91,6 @@ class JambaMambaMixer(nn.Module):
bias=True, bias=True,
skip_bias_add=True) skip_bias_add=True)
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter( self.A = nn.Parameter(
torch.empty( torch.empty(
...@@ -120,8 +100,10 @@ class JambaMambaMixer(nn.Module): ...@@ -120,8 +100,10 @@ class JambaMambaMixer(nn.Module):
)) ))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": weight_loader}) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
self.intermediate_size, self.intermediate_size,
...@@ -138,42 +120,48 @@ class JambaMambaMixer(nn.Module): ...@@ -138,42 +120,48 @@ class JambaMambaMixer(nn.Module):
self.c_layernorm = RMSNorm(self.ssm_state_size, self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
def mamba_forward(self, def forward(self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, attn_metadata: AttentionMetadata,
cache_params: MambaCacheParams = None): mamba_cache_params: MambaCacheParams):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=1) hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2)) self.conv1d.weight.size(2))
if cache_params is not None and not cache_params.is_prompt:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_state.copy_(conv_states)
hidden_states, _ = causal_conv1d_fn( if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states, hidden_states,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
activation=self.activation, activation=self.activation,
) conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C # 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split( time_step, B, C = torch.split(
ssm_parameters, ssm_parameters,
...@@ -184,72 +172,47 @@ class JambaMambaMixer(nn.Module): ...@@ -184,72 +172,47 @@ class JambaMambaMixer(nn.Module):
B = self.b_layernorm(B.contiguous()) B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous()) C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x) # 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr( time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None) self.dt_proj, "bias") else None)
if cache_params is not None and not cache_params.is_prompt:
scan_outputs = selective_state_update( if attn_metadata.query_start_loc is not None \
cache_params.ssm_state, and attn_metadata.context_lens_tensor is not None:
hidden_states[..., 0], scan_outputs = selective_scan_fn(
discrete_time_step[..., 0],
self.A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
hidden_states, hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step, discrete_time_step,
self.A, self.A,
B.transpose(1, 2), B.transpose(-2, -1),
C.transpose(1, 2), C.transpose(-2, -1),
self.D.float(), self.D.float(),
gate, gate,
time_proj_bias, time_proj_bias,
delta_softplus=True, delta_softplus=True,
return_last_state=True, cache_indices=mamba_cache_params.state_indices_tensor,
) has_initial_state=attn_metadata.context_lens_tensor > 0,
if ssm_state is not None and cache_params is not None: query_start_loc=attn_metadata.query_start_loc)
cache_params.ssm_state.copy_(ssm_state) else:
scan_outputs = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection # 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states return contextualized_states
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
if attn_metadata.prefill_metadata is not None:
offset = 0
for i, prompt_len in enumerate(
attn_metadata.prefill_metadata.seq_lens):
cache = MambaCacheParams(True,
conv_state=conv_state[i].unsqueeze(0),
ssm_state=ssm_state[i].unsqueeze(0))
hidden_states[offset:offset + prompt_len].copy_(
self.mamba_forward(hidden_states[offset:offset +
prompt_len].unsqueeze(0),
cache_params=cache)[0])
offset += prompt_len
else:
cache = MambaCacheParams(False,
conv_state=conv_state,
ssm_state=ssm_state)
hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
cache_params=cache)
hidden_states = hidden_states.squeeze(1)
return hidden_states
class JambaMoE(nn.Module): class JambaMoE(nn.Module):
...@@ -323,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -323,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.config = config self.config = config
self.mamba = JambaMambaMixer(config, layer_idx) self.mamba = JambaMambaMixer(config)
num_experts = config.layers_num_experts[layer_idx] num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
...@@ -338,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -338,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
conv_state: torch.Tensor, mamba_cache_params: MambaCacheParams,
ssm_state: torch.Tensor,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -349,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -349,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, hidden_states = self.mamba(hidden_states, attn_metadata,
ssm_state) mamba_cache_params)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual) hidden_states, residual)
...@@ -507,17 +469,14 @@ class JambaModel(nn.Module): ...@@ -507,17 +469,14 @@ class JambaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, mamba_cache_params: MambaCacheParams,
ssm_state: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
kv_cache = None kv_cache = None
current_ssm_state = None layer_mamba_cache_params = None
current_conv_state = None
if isinstance(layer, JambaAttentionDecoderLayer): if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) // kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period] self.config.attn_layer_period]
...@@ -525,8 +484,8 @@ class JambaModel(nn.Module): ...@@ -525,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer = i - (1 + current_state_layer = i - (1 +
(i - self.config.attn_layer_offset) (i - self.config.attn_layer_offset)
// self.config.attn_layer_period) // self.config.attn_layer_period)
current_ssm_state = ssm_state[current_state_layer] layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_conv_state = conv_state[current_state_layer] current_state_layer)
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
...@@ -534,9 +493,7 @@ class JambaModel(nn.Module): ...@@ -534,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
conv_state=current_conv_state, mamba_cache_params=layer_mamba_cache_params)
ssm_state=current_ssm_state,
)
hidden_states, _ = self.final_layernorm(hidden_states, residual) hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -571,8 +528,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -571,8 +528,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None, scheduler_config: Optional[SchedulerConfig] = None,
) -> None: ) -> None:
assert not scheduler_config.chunked_prefill_enabled, \
"Jamba currently does not support chunked prefill"
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching" "Jamba currently does not support prefix caching"
...@@ -596,10 +551,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -596,10 +551,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
) )
# Used to track and store by the Mamba cache between steps. # Used to track and store by the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() self.mamba_cache: Optional[MambaCacheManager] = None
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -611,242 +564,51 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -611,242 +564,51 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs): **kwargs):
if not self.mamba_cache: if self.mamba_cache is None:
self._prepare_mamba_cache() max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
if "seqlen_agnostic_capture_inputs" not in kwargs: else max(_BATCH_SIZES_TO_CAPTURE) + 2)
# We get here only on Prefill/Eager mode runs
assert all( layers_type = self.config.layers_block_type
key in kwargs num_mamba_layers = sum(
for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) [layer_type == "mamba" for layer_type in layers_type])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] self.mamba_cache = MambaCacheManager(
finished_requests_ids = kwargs["finished_requests_ids"] self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
self._release_mamba_cache(finished_requests_ids) *self._get_mamba_cache_shape())
batch_size = input_ids.shape[0] (
if attn_metadata.prefill_metadata: mamba_cache_tensors,
batch_size = len(request_ids_to_seq_ids) state_indices_tensor,
mamba_cache = self._prepare_current_run_mamba_cache( ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
request_ids_to_seq_ids, batch_size, finished_requests_ids) **kwargs)
else: mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
# CUDA graph capturing runs mamba_cache_tensors[1],
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"] state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache[0], attn_metadata, mamba_cache_params)
mamba_cache[1])
return hidden_states return hidden_states
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
batch_size: int, finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache index for requests that run
# and finish right after
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
""" return self.mamba_cache.copy_inputs_before_cuda_graphs(
Copy the relevant Mamba cache into the CUDA graph input buffer input_buffers, **kwargs)
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
""" return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
conv_state_shape = ( conv_state_shape = (
self.config.mamba_expand * hidden_size // world_size, self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_conv, self.config.mamba_d_conv - 1,
) )
temporal_state_shape = ( temporal_state_shape = (
self.config.mamba_expand * self.config.hidden_size // world_size, self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_state, self.config.mamba_d_state,
) )
return conv_state_shape, temporal_state_shape return conv_state_shape, temporal_state_shape
def _prepare_mamba_cache(self):
dtype = self.lm_head.weight.dtype
layers_type = self.config.layers_block_type
mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -30,6 +30,7 @@ import os ...@@ -30,6 +30,7 @@ import os
import re import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -39,8 +40,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -39,8 +40,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.pooler import Pooler, PoolingType
QuantizationConfig) from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
...@@ -49,12 +50,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,12 +50,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
...@@ -77,12 +80,15 @@ class LlamaMLP(nn.Module): ...@@ -77,12 +80,15 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2, output_sizes=[intermediate_size] * 2,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj",
self.down_proj = RowParallelLinear(input_size=intermediate_size, )
output_size=hidden_size, self.down_proj = RowParallelLinear(
bias=bias, input_size=intermediate_size,
quant_config=quant_config, output_size=hidden_size,
prefix=f"{prefix}.down_proj") bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -166,12 +172,15 @@ class LlamaAttention(nn.Module): ...@@ -166,12 +172,15 @@ class LlamaAttention(nn.Module):
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
) )
self.attn = Attention(self.num_heads,
self.head_dim, self.attn = Attention(
self.scaling, self.num_heads,
num_kv_heads=self.num_kv_heads, self.head_dim,
cache_config=cache_config, self.scaling,
quant_config=quant_config) num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
...@@ -260,12 +269,10 @@ class LlamaDecoderLayer(nn.Module): ...@@ -260,12 +269,10 @@ class LlamaDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(positions=positions,
positions=positions, hidden_states=hidden_states,
hidden_states=hidden_states, kv_cache=kv_cache,
kv_cache=kv_cache, attn_metadata=attn_metadata)
attn_metadata=attn_metadata,
)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
...@@ -274,6 +281,13 @@ class LlamaDecoderLayer(nn.Module): ...@@ -274,6 +281,13 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__( def __init__(
...@@ -307,12 +321,27 @@ class LlamaModel(nn.Module): ...@@ -307,12 +321,27 @@ class LlamaModel(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix), prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -338,13 +367,9 @@ class LlamaModel(nn.Module): ...@@ -338,13 +367,9 @@ class LlamaModel(nn.Module):
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(positions, hidden_states,
positions, kv_caches[i - self.start_layer],
hidden_states, attn_metadata, residual)
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -355,153 +380,6 @@ class LlamaModel(nn.Module): ...@@ -355,153 +380,6 @@ class LlamaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class LlamaForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
...@@ -513,8 +391,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -513,8 +391,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if ("rotary_emb.cos_cached" in name if ("rotary_emb.cos_cached" in name
...@@ -522,11 +398,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -522,11 +398,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name): if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization # Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name] param = params_dict[scale_name]
...@@ -535,7 +406,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -535,7 +406,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
...@@ -566,7 +437,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -566,7 +437,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None : if self.use_llama_nn and self.quant_method is None :
lay_key_words = [ lay_key_words = [
...@@ -574,7 +445,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -574,7 +445,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight", "mlp.down_proj.weight",
"lm_head.weight" # "lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
...@@ -656,7 +527,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -656,7 +527,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
k=weight_data.shape[0] k=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1) _weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight) weight_data.data.copy_(_weight)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state # make sure to leave KV cache scale factors in a known good (dummy) state
...@@ -667,8 +538,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -667,8 +538,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quantization_param_path, tp_rank, tp_size, quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers, self.config.num_hidden_layers,
self.config.__class__.model_type): self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity): if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn layer_self_attn = self.layers[layer_idx].self_attn
if is_hip(): if is_hip():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
...@@ -682,13 +553,161 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -682,13 +553,161 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
raise RuntimeError("Self attention has no KV cache scaling " raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!") "factor attribute!")
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# This function is used to remap the mistral format as # This function is used to remap the mistral format as
# used by Mistral and Llama <=2 # used by Mistral and Llama <=2
def maybe_remap_mistral( def maybe_remap_mistral(
self, name: str, self,
loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]: name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w, n_heads): def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size attn_out = self.config.hidden_size
...@@ -711,3 +730,52 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -711,3 +730,52 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
name = name.replace(item, mapping[item]) name = name.replace(item, mapping[item])
return name, loaded_weight return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
class LlamaEmbeddingModel(nn.Module):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.forward(input_ids, positions, kv_caches,
attn_metadata, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.model.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -8,11 +9,10 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig ...@@ -8,11 +9,10 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -21,12 +21,12 @@ from vllm.utils import is_list_of ...@@ -21,12 +21,12 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens, dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip) input_processor_for_clip)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ...@@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
...@@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_siglip( return input_processor_for_siglip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -198,7 +198,7 @@ def _init_vision_tower(hf_config: LlavaConfig): ...@@ -198,7 +198,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
...@@ -220,6 +220,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -220,6 +220,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -315,7 +325,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -315,7 +325,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LLaVA-1.5. """Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
...@@ -351,26 +361,32 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -351,26 +361,32 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
See also: See also:
:class:`LlavaImageInputs` :class:`LlavaImageInputs`
""" """
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -391,19 +407,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -391,19 +407,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -11,10 +12,9 @@ from typing_extensions import NotRequired ...@@ -11,10 +12,9 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -23,13 +23,13 @@ from vllm.utils import is_list_of ...@@ -23,13 +23,13 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size, dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
...@@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, ...@@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_llava_next(ctx: InputContext,
multi_modal_data = llm_inputs.get("multi_modal_data") inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig) hf_config = ctx.get_hf_config(LlavaNextConfig)
...@@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_siglip( return input_processor_for_siglip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig): ...@@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: LlavaNextConfig, config: LlavaNextConfig,
...@@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config) self.vision_tower = _init_vision_tower(config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
...@@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.image_newline = nn.Parameter( self.make_empty_intermediate_tensors = (
torch.empty(config.text_config.hidden_size)) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT. """Run forward pass for LlaVA-NeXT.
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
...@@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
See also: See also:
:class:`LlavaNextImageInputs` :class:`LlavaNextImageInputs`
""" """
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -627,27 +642,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -627,27 +642,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load newline
for name, loaded_weight in weights_group["image_newline"]:
assert name == ""
param = self.image_newline
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
import math import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -10,12 +11,11 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, ...@@ -10,12 +11,11 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -25,10 +25,10 @@ from vllm.sequence import IntermediateTensors ...@@ -25,10 +25,10 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
# For profile run # For profile run
...@@ -140,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, ...@@ -140,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
def input_processor_for_llava_next_video(ctx: InputContext, def input_processor_for_llava_next_video(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data: if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs return inputs
video_data = multi_modal_data["video"] video_data = multi_modal_data["video"]
model_config = ctx.model_config model_config = ctx.model_config
...@@ -161,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext, ...@@ -161,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext,
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index, placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size, repeat_count=video_feature_size,
) )
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
raise NotImplementedError( raise NotImplementedError(
...@@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module): ...@@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module):
"video", get_max_llava_next_video_tokens) "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: LlavaNextVideoConfig, config: LlavaNextVideoConfig,
...@@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config) self.vision_tower = _init_vision_tower(config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector( self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_video_pixel_values( def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, List[torch.Tensor]]
...@@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT-Video. """Run forward pass for LlaVA-NeXT-Video.
Args: Args:
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values_videos: Pixels in each frames for each input videos. pixel_values_videos: Pixels in each frames for each input videos.
""" """
video_input = self._parse_and_validate_video_input(**kwargs) if intermediate_tensors is not None:
# merge video embeddings into input embeddings
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -445,19 +458,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -445,19 +458,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(
weights_group = group_weights_with_prefix(weights) self,
# This model doesn't support images for now
# load vision encoder ignore_unexpected_prefixes=["image_newline"],
self.vision_tower.load_weights(weights_group["vision_tower"]) )
loader.load_weights(weights)
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
import math import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -14,13 +15,11 @@ from typing_extensions import NotRequired ...@@ -14,13 +15,11 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
from vllm.logger import init_logger token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
...@@ -31,14 +30,12 @@ from vllm.utils import is_list_of ...@@ -31,14 +30,12 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size, dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size, dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__)
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
...@@ -253,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, ...@@ -253,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
def input_processor_when_multimodal_input_image(ctx: InputContext, def input_processor_when_multimodal_input_image(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig) hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
...@@ -291,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, ...@@ -291,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -299,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, ...@@ -299,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
return input_processor_for_siglip( return input_processor_for_siglip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -309,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, ...@@ -309,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
def input_processor_when_multimodal_input_video(ctx: InputContext, def input_processor_when_multimodal_input_video(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data: if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs return inputs
video_data = multi_modal_data["video"] video_data = multi_modal_data["video"]
model_config = ctx.model_config model_config = ctx.model_config
...@@ -327,15 +324,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, ...@@ -327,15 +324,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index, placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size, repeat_count=video_feature_size,
) )
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
raise NotImplementedError( raise NotImplementedError(
...@@ -346,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, ...@@ -346,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
def input_processor_for_llava_onevision(ctx: InputContext, def input_processor_for_llava_onevision(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data): and "image" not in multi_modal_data):
return llm_inputs return inputs
if "image" in multi_modal_data: if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, llm_inputs) return input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data: if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, llm_inputs) return input_processor_when_multimodal_input_video(ctx, inputs)
msg = "Unsupported multi data type" msg = "Unsupported multi data type"
raise NotImplementedError(msg) raise NotImplementedError(msg)
...@@ -414,7 +411,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module): ...@@ -414,7 +411,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
"video", get_max_llava_onevision_video_tokens) "video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: LlavaOnevisionConfig, config: LlavaOnevisionConfig,
...@@ -434,6 +432,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -434,6 +432,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -805,39 +813,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -805,39 +813,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-Onevision. """Run forward pass for LlaVA-Onevision.
Args: Args:
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values_videos: Pixels in each frames for each input videos. pixel_values_videos: Pixels in each frames for each input videos.
""" """
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if intermediate_tensors is not None:
# merge video embeddings into input embeddings
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -858,19 +869,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -858,19 +869,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
# coding=utf-8
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def __init__(self, config: MambaConfig, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
bias=config.use_conv_bias,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(self.hidden_size,
[self.intermediate_size] * 2,
bias=config.use_bias)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.intermediate_size,
self.time_step_rank + self.ssm_state_size * 2,
bias=False,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(self.time_step_rank,
self.intermediate_size,
bias=True,
skip_bias_add=True)
tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
self.intermediate_size // tp_size,
self.ssm_state_size,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
self.out_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=config.use_bias,
input_is_parallel=True,
)
self.activation = config.hidden_act
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
class MambaDecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.mixer = MambaMixer(config, layer_idx)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata,
mamba_cache_params)
return hidden_states, residual
class MambaModel(nn.Module):
def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embeddings = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
decoder_layers = []
for i in range(config.num_hidden_layers):
decoder_layers.append(
MambaDecoderLayer(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
self.norm_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
hidden_states = self.embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(i))
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__(
self,
config: MambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.backbone = MambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = self.backbone.embeddings
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, self.config.num_hidden_layers,
max_batch_size, *self._get_mamba_cache_shape())
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape = (
self.config.intermediate_size // world_size,
self.config.conv_kernel - 1,
)
temporal_state_shape = (
self.config.intermediate_size // world_size,
self.config.state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
from dataclasses import dataclass
from typing import Dict, List
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
class MambaCacheManager:
def __init__(self, dtype, num_mamba_layers, max_batch_size,
conv_state_shape, temporal_state_shape):
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda")
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda")
self.mamba_cache = (conv_state, temporal_state)
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
def current_run_tensors(self, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, **kwargs):
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else:
# CUDA graph capturing runs
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return (mamba_cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.mamba_cache, state_indices_tensor)
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -30,10 +30,10 @@ from transformers import PretrainedConfig ...@@ -30,10 +30,10 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MiniCPMMoE(nn.Module): class MiniCPMMoE(nn.Module):
...@@ -151,6 +152,7 @@ class MiniCPMMLP(nn.Module): ...@@ -151,6 +152,7 @@ class MiniCPMMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
hidden_act_param: float,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -162,10 +164,13 @@ class MiniCPMMLP(nn.Module): ...@@ -162,10 +164,13 @@ class MiniCPMMLP(nn.Module):
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config)
if hidden_act != "silu": if hidden_act == "silu":
self.act_fn = SiluAndMul()
elif hidden_act == "fatrelu":
self.act_fn = FatreluAndMul(threshold=hidden_act_param)
else:
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu and fatrelu are supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
...@@ -264,7 +269,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -264,7 +269,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
...@@ -303,6 +308,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -303,6 +308,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=self.config.intermediate_size, intermediate_size=self.config.intermediate_size,
hidden_act=self.config.hidden_act, hidden_act=self.config.hidden_act,
hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
quant_config=self.quant_config, quant_config=self.quant_config,
) )
else: else:
...@@ -346,10 +352,11 @@ class MiniCPMModel(nn.Module): ...@@ -346,10 +352,11 @@ class MiniCPMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -365,15 +372,24 @@ class MiniCPMModel(nn.Module): ...@@ -365,15 +372,24 @@ class MiniCPMModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self._init_layers() self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
def _init_layers(self): def _init_layers(
self.layers = nn.ModuleList([ self,
MiniCPMDecoderLayer(self.config, self.cache_config, prefix: str,
self.quant_config) config: PretrainedConfig,
for _ in range(self.config.num_hidden_layers) cache_config: Optional[CacheConfig],
]) quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
...@@ -387,27 +403,36 @@ class MiniCPMModel(nn.Module): ...@@ -387,27 +403,36 @@ class MiniCPMModel(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = intermediate_tensors["hidden_states"]
residual = None residual = intermediate_tensors["residual"]
for i in range(len(self.layers)): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class MiniCPMForCausalLM(nn.Module, SupportsLoRA): class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -454,22 +479,25 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -454,22 +479,25 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
unpadded_vocab_size = config.vocab_size unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size unpadded_vocab_size += lora_config.lora_extra_vocab_size
if not self.config.tie_word_embeddings: self.lm_head = ParallelLMHead(
self.lm_head = ParallelLMHead( unpadded_vocab_size,
unpadded_vocab_size, config.hidden_size,
config.hidden_size, org_num_embeddings=config.vocab_size,
org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE
padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel
# We need bigger padding if using lora for kernel # compatibility
# compatibility if not lora_config else lora_config.lora_vocab_padding_size,
if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config,
quant_config=quant_config, )
) if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
self.scale_width = self.config.hidden_size / self.config.dim_model_base self.scale_width = self.config.hidden_size / self.config.dim_model_base
self.logits_processor = LogitsProcessor(unpadded_vocab_size, self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self): def _init_model(self):
self.model = MiniCPMModel(config=self.config, self.model = MiniCPMModel(config=self.config,
...@@ -484,7 +512,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -484,7 +512,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -495,11 +523,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -495,11 +523,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
hidden_states = hidden_states / self.scale_width hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings: logits = self.logits_processor(self.lm_head, hidden_states,
lm_head = self.model.embed_tokens
else:
lm_head = self.lm_head
logits = self.logits_processor(lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
...@@ -548,6 +572,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -548,6 +572,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -557,6 +583,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -557,6 +583,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -568,6 +596,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -568,6 +596,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional ...@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
...@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM, MiniCPMForCausalLM,
MiniCPMModel) MiniCPMModel)
from .utils import make_layers
class MiniCPM3Attention(nn.Module): class MiniCPM3Attention(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
qk_nope_head_dim: int, qk_nope_head_dim: int,
...@@ -199,15 +201,43 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): ...@@ -199,15 +201,43 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
class MiniCPM3Model(MiniCPMModel): class MiniCPM3Model(MiniCPMModel):
def _init_layers(self): def _init_layers(
self.layers = nn.ModuleList([ self,
MiniCPM3DecoderLayer(self.config, self.cache_config, prefix: str,
self.quant_config) config: PretrainedConfig,
for _ in range(self.config.num_hidden_layers) cache_config: Optional[CacheConfig],
]) quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
class MiniCPM3ForCausalLM(MiniCPMForCausalLM): class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"kv_a_proj_with_mqa",
"q_a_proj",
"q_b_proj",
"kv_b_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM
def _init_model(self): def _init_model(self):
self.model = MiniCPM3Model(config=self.config, self.model = MiniCPM3Model(config=self.config,
......
...@@ -24,33 +24,33 @@ ...@@ -24,33 +24,33 @@
import math import math
import re import re
from functools import partial from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
TypedDict) Tuple, TypedDict, Union)
import torch import torch
import torch.types import torch.types
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
from vllm.model_executor.layers.linear import ReplicatedLinear token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed) get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
...@@ -59,16 +59,19 @@ from vllm.multimodal.utils import cached_get_tokenizer ...@@ -59,16 +59,19 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
"llm.model": "llm",
} }
RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImageInput(TypedDict):
class MiniCPMVRawImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds.""" """Input mapper input with auxiliary data for computing image bounds."""
image: Image.Image image: RawImageType
# Image bounds token ids in 0-dim scaler tensor. # Image bounds token ids in 0-dim scaler tensor.
im_start_id: torch.Tensor im_start_id: torch.Tensor
...@@ -78,7 +81,8 @@ class MiniCPMVImageInput(TypedDict): ...@@ -78,7 +81,8 @@ class MiniCPMVImageInput(TypedDict):
class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImagePixelInputs(TypedDict):
pixel_values: List[torch.Tensor] type: Literal["pixel_values"]
data: List[torch.Tensor]
""" """
Shape: `(batch_size * num_images, num_channels, height, width)` Shape: `(batch_size * num_images, num_channels, height, width)`
...@@ -101,59 +105,28 @@ class MiniCPMVImagePixelInputs(TypedDict): ...@@ -101,59 +105,28 @@ class MiniCPMVImagePixelInputs(TypedDict):
""" """
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
"""
class BaseResampler(nn.Module): image_bounds: torch.Tensor
""" """
A 2D perceiver-resampler network with one cross attention layers by Shape: `(batch_size * num_images, 2)`
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs: This should be in `(start, stop)` format.
A tensor with the shape of (grid_size**2, embed_dim)
""" """
def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
) -> None:
super().__init__()
self.num_queries = num_queries MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
self.embed_dim = embed_dim MiniCPMVImageEmbeddingInputs]
self.num_heads = num_heads
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: (
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2_5(BaseResampler): class Resampler2_5(BaseResampler):
...@@ -246,22 +219,22 @@ class Resampler2_5(BaseResampler): ...@@ -246,22 +219,22 @@ class Resampler2_5(BaseResampler):
def _build_image_input(ctx: InputContext, def _build_image_input(ctx: InputContext,
image: Image.Image) -> MiniCPMVImageInput: image: RawImageType) -> MiniCPMVRawImageInput:
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer, ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code) trust_remote_code=ctx.model_config.trust_remote_code)
if hasattr(tokenizer, "slice_start_id"): if hasattr(tokenizer, "slice_start_id"):
return MiniCPMVImageInput( return MiniCPMVRawImageInput(
image=image, image=image,
im_start_id=torch.tensor(tokenizer.im_start_id), im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id), im_end_id=torch.tensor(tokenizer.im_end_id),
slice_start_id=torch.tensor(tokenizer.slice_start_id), slice_start_id=torch.tensor(tokenizer.slice_start_id),
slice_end_id=torch.tensor(tokenizer.slice_end_id)) slice_end_id=torch.tensor(tokenizer.slice_end_id))
else: else:
return MiniCPMVImageInput(image=image, return MiniCPMVRawImageInput(
im_start_id=torch.tensor( image=image,
tokenizer.im_start_id), im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id)) im_end_id=torch.tensor(tokenizer.im_end_id))
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
...@@ -284,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): ...@@ -284,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return SequenceData.from_token_counts((0, seq_len)) return SequenceData.from_prompt_token_counts((0, seq_len))
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
...@@ -307,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, ...@@ -307,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
return seq_data, mm_data return seq_data, mm_data
def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
version = get_version_by_config(model_config.hf_config) version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
...@@ -325,27 +298,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -325,27 +298,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return image_processor. \ return image_processor. \
get_slice_image_placeholder(image_size, num_image) get_slice_image_placeholder(image_size, num_image)
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
token_ids = inputs.get("prompt_token_ids")
if prompt is None: if prompt is None:
token_ids = llm_inputs.get("prompt_token_ids")
prompt = tokenizer.decode(token_ids) prompt = tokenizer.decode(token_ids)
pattern = "(<image>./</image>)" pattern = "(<image>./</image>)"
images = multi_modal_data["image"] images = multi_modal_data["image"]
if isinstance(images, Image.Image):
images = [images]
image_tags = re.findall(pattern, prompt) image_tags = re.findall(pattern, prompt)
if len(image_tags) == 0: if len(image_tags) == 0:
new_token_ids = token_ids new_token_ids = token_ids
new_prompt = prompt new_prompt = prompt
else: else:
if isinstance(images, dict):
image_size_list = images.get("image_size_list")
images = [images.get("image_embeds")]
else:
if isinstance(images, Image.Image):
images = [images]
image_size_list = [image.size for image in images]
text_chunks = prompt.split(pattern) text_chunks = prompt.split(pattern)
new_prompt_chunks: List[str] = [] new_prompt_chunks: List[str] = []
for i in range(len(images)): for i in range(len(image_size_list)):
new_prompt_chunks += [ new_prompt_chunks += [
text_chunks[i], text_chunks[i],
get_placeholder(images[i].size, i) get_placeholder(image_size_list[i], i)
] ]
new_prompt_chunks.append(text_chunks[-1]) new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks) new_prompt = "".join(new_prompt_chunks)
...@@ -355,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -355,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
_build_image_input(ctx, image) for image in images _build_image_input(ctx, image) for image in images
] ]
llm_inputs = LLMInputs( return token_inputs(
prompt_token_ids=new_token_ids, prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
return llm_inputs
def input_mapper_for_minicpmv(ctx: InputContext, data: object): def input_mapper_for_minicpmv(ctx: InputContext, data: object):
...@@ -375,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): ...@@ -375,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
if not isinstance(data, list): if not isinstance(data, list):
raise ValueError( raise ValueError(
"Image input must be list of MiniCPMVImageInput, got (%s)", data) "Image input must be list of MiniCPMVImageInput, got (%s)", data)
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \ if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
.data batch_data = {
"image_embeds": data[0]['image'],
}
else:
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data
if len(data) > 0: if len(data) > 0:
batch_data["im_start_id"] = data[0]["im_start_id"] batch_data["im_start_id"] = data[0]["im_start_id"]
...@@ -389,7 +372,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): ...@@ -389,7 +372,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
return MultiModalInputs(batch_data) return MultiModalInputs(batch_data)
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
...@@ -426,10 +409,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -426,10 +409,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
def get_embedding( def get_embedding(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImagePixelInputs], image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"): if hasattr(self.config, "scale_emb"):
...@@ -438,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -438,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if image_inputs is None: # No image if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device) vision_hidden_states = torch.tensor([], device=input_ids.device)
else: else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs) if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)
# See NOTE in _parse_and_validate_inputs # See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"] image_bounds = image_inputs["image_bounds"]
...@@ -489,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -489,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
**kwargs: object, **kwargs: object,
) -> Optional[MiniCPMVImagePixelInputs]: ) -> Optional[MiniCPMVImageInputs]:
pixel_values = kwargs.pop("pixel_values", []) pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", []) tgt_sizes = kwargs.pop("tgt_sizes", [])
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_embeds is not None:
return MiniCPMVImageEmbeddingInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
...@@ -526,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -526,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if len(pixel_values_flat) == 0: if len(pixel_values_flat) == 0:
return None return None
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
if im_start_id is None: if im_start_id is None:
return None return None
...@@ -537,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -537,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
image_bounds=self._get_image_bounds(input_ids, im_start_id, image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id, im_end_id, slice_start_id,
slice_end_id), slice_end_id),
pixel_values=pixel_values_flat, data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat), tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
) )
def forward( def forward(
...@@ -550,9 +552,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -550,9 +552,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
output = self.llm( output = self.llm(
input_ids=None, input_ids=None,
...@@ -609,6 +614,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -609,6 +614,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
if is_pp_missing_parameter(
name.replace(weight_name, param_name), self):
continue
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -616,11 +624,21 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -616,11 +624,21 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
else: else:
use_default_weight_loading = True use_default_weight_loading = True
if use_default_weight_loading: if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(language_model="llm",
connector="resampler",
tower_model="vpm")
def init_llm( def init_llm(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -643,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -643,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool: def is_default_weight_loading(self, name: str) -> bool:
...@@ -669,9 +687,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -669,9 +687,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module: ) -> nn.Module:
return MiniCPMModel(config,
cache_config=cache_config, return LLMWrapper(MiniCPMModel(config,
quant_config=quant_config) cache_config=cache_config,
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(self) -> nn.Module:
# TODO :refactor this vision model # TODO :refactor this vision model
...@@ -697,6 +717,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -697,6 +717,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return model return model
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16): with set_default_torch_dtype(torch.float16):
resampler = Resampler2( resampler = Resampler2(
...@@ -733,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -733,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res.append(self.resampler(vision_embedding, tgt_size)) res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res) return torch.vstack(res)
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["data"]
return self.get_vision_embedding(pixel_values) return self.get_vision_embedding(pixel_values)
...@@ -743,7 +766,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -743,7 +766,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return "resampler" in name or "vpm" in name return "resampler" in name or "vpm" in name
class MiniCPMV2_5(MiniCPMVBaseModel): class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
...@@ -751,6 +801,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel): ...@@ -751,6 +801,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
multimodal_config: MultiModalConfig, multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 5) assert self.version == (2, 5)
...@@ -761,9 +812,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel): ...@@ -761,9 +812,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module: ) -> nn.Module:
return LlamaModel(config, return LLMWrapper(LlamaModel(config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config) model = Idefics2VisionTransformer(self.config.vision_config)
...@@ -792,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel): ...@@ -792,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
vision_embedding = self.resampler(vision_embedding, tgt_sizes) vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding return vision_embedding
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"] tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
...@@ -825,7 +877,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel): ...@@ -825,7 +877,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
return "resampler" in name return "resampler" in name
class MiniCPMV2_6(MiniCPMVBaseModel): class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
...@@ -843,20 +923,15 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -843,20 +923,15 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module: ) -> nn.Module:
return Qwen2Model(config,
cache_config=cache_config, return LLMWrapper(Qwen2Model(config,
quant_config=quant_config) cache_config=cache_config,
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
from vllm.model_executor.models.na_vit import SiglipVisionTransformer
if self.config._attn_implementation == "flash_attention_2": model = Idefics2VisionTransformer(self.config.vision_config)
self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not support sdpa
self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
return model return model
...@@ -870,7 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -870,7 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
num_heads=embed_dim // 128, num_heads=embed_dim // 128,
kv_dim=vision_dim, kv_dim=vision_dim,
) )
return resampler return resampler
def get_vision_embedding( def get_vision_embedding(
...@@ -883,12 +957,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -883,12 +957,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
pixel_values, pixel_values,
patch_attention_mask=patch_attn_mask, patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes, tgt_sizes=tgt_sizes,
).last_hidden_state )
return vision_embedding return vision_embedding
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"] tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
...@@ -915,12 +989,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -915,12 +989,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
all_pixel_values.type(dtype), all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask, patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes, tgt_sizes=tgt_sizes,
).last_hidden_state )
return self.resampler(vision_embedding, tgt_sizes) return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool: def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name return "resampler" in name
_SUPPORT_VERSION = { _SUPPORT_VERSION = {
...@@ -934,20 +1008,25 @@ _SUPPORT_VERSION = { ...@@ -934,20 +1008,25 @@ _SUPPORT_VERSION = {
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel): class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
""" """
Different versions of MiniCPMV use different visual encoders and LLMs, Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them. bitsandbytes in vLLM. Therefore, it is necessary to separate them.
""" """
# Ensure that the LoRA support check passes when the class is not
def __new__( # initialized, but set all these attributes to empty.
cls, packed_modules_mapping = {}
config: PretrainedConfig, supported_lora_modules = []
multimodal_config: MultiModalConfig, embedding_modules = {}
cache_config: Optional[CacheConfig] = None, embedding_padding_modules = []
quant_config: Optional[QuantizationConfig] = None,
): def __new__(cls,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
if not hasattr(config, "version"): if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64: if config.hidden_size == 2304 and config.query_num == 64:
version = (2, 0) version = (2, 0)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import is_pp_missing_parameter, make_layers from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
...@@ -276,6 +276,9 @@ class MixtralModel(nn.Module): ...@@ -276,6 +276,9 @@ class MixtralModel(nn.Module):
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -284,7 +287,7 @@ class MixtralModel(nn.Module): ...@@ -284,7 +287,7 @@ class MixtralModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -306,7 +309,7 @@ class MixtralModel(nn.Module): ...@@ -306,7 +309,7 @@ class MixtralModel(nn.Module):
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA): class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
...@@ -319,10 +322,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -319,10 +322,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [ supported_lora_modules = [
"qkv_proj", "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3",
"o_proj", "gate"
"embed_tokens",
"lm_head",
] ]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
...@@ -365,6 +366,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -365,6 +366,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -373,7 +376,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -373,7 +376,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -387,20 +390,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -387,20 +390,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: Optional[torch.Tensor], logits: Optional[torch.Tensor],
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -31,7 +31,7 @@ from transformers import MixtralConfig ...@@ -31,7 +31,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -296,6 +299,7 @@ class MixtralModel(nn.Module): ...@@ -296,6 +299,7 @@ class MixtralModel(nn.Module):
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -305,13 +309,15 @@ class MixtralModel(nn.Module): ...@@ -305,13 +309,15 @@ class MixtralModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
MixtralDecoderLayer(config, config.num_hidden_layers,
cache_config, lambda prefix: MixtralDecoderLayer(
quant_config=quant_config) config, cache_config, quant_config=quant_config),
for _ in range(config.num_hidden_layers) prefix=f"{prefix}.layers")
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -319,19 +325,30 @@ class MixtralModel(nn.Module): ...@@ -319,19 +325,30 @@ class MixtralModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata, kv_caches[i - self.start_layer],
residual) attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module): class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(
...@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module): ...@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module):
if ("block_sparse_moe.experts." in name if ("block_sparse_moe.experts." in name
and name not in params_dict): and name not in params_dict):
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# limitations under the License. # limitations under the License.
"""PyTorch Mllama model.""" """PyTorch Mllama model."""
import math import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -28,12 +28,16 @@ from transformers.modeling_outputs import (BaseModelOutput, ...@@ -28,12 +28,16 @@ from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast) CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import ( from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas) get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
get_cross_attention_token_mask)
import vllm.distributed.parallel_state as ps import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputContext)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -47,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import SequenceData
from .clip import CLIPMLP from .clip import CLIPMLP
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
...@@ -72,31 +76,45 @@ class MllamaImagePixelInputs(TypedDict): ...@@ -72,31 +76,45 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs # TODO: support LlamaImageEmbeddingInputs
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == MLLAMA_IMAGE_TOKEN_ID:
num_images += 1
elif num_images > 0:
break
return num_images
def input_processor_for_mllama(ctx: InputContext,
inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
# move encoder_prompt to prompt # move encoder_prompt to prompt
if llm_inputs.get("prompt") is None: if inputs.get("prompt") is None:
llm_inputs["prompt"] = llm_inputs["encoder_prompt"] inputs["prompt"] = inputs["encoder_prompt"]
llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
# process multi-modal data # process multi-modal data
assert "decoder_multi_modal_data" not in llm_inputs, \ multi_modal_data = inputs.get("encoder_multi_modal_data")
"multi-modal data should be put in encoder message of mllama"
multi_modal_data = llm_inputs.get("encoder_multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data \ if multi_modal_data is None or "image" not in multi_modal_data \
or multi_modal_data["image"] is None: or multi_modal_data["image"] is None:
# text-only # text-only
llm_inputs["encoder_prompt"] = "" inputs["encoder_prompt"] = ""
llm_inputs["encoder_prompt_token_ids"] = [] inputs["encoder_prompt_token_ids"] = []
llm_inputs["encoder_multi_modal_data"] = {} inputs["encoder_multi_modal_data"] = {}
return llm_inputs return inputs
# get num_tiles
if isinstance(multi_modal_data['image'], Image.Image): if isinstance(multi_modal_data['image'], Image.Image):
multi_modal_data['image'] = [multi_modal_data['image']] multi_modal_data['image'] = [multi_modal_data['image']]
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config hf_config = ctx.model_config.hf_config
num_tiles = 0 num_tiles = 0
for image in multi_modal_data["image"]: for image in multi_modal_data["image"][::-1]:
width, height = image.size width, height = image.size
tile_size = hf_config.vision_config.image_size tile_size = hf_config.vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas( canvas_height, canvas_width = get_optimal_tiled_canvas(
...@@ -108,17 +126,21 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -108,17 +126,21 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
num_tiles_height = canvas_height // tile_size num_tiles_height = canvas_height // tile_size
num_tiles_width = canvas_width // tile_size num_tiles_width = canvas_width // tile_size
num_tiles += num_tiles_height * num_tiles_width num_tiles += num_tiles_height * num_tiles_width
num_decode_images -= 1
if num_decode_images == 0:
break
# set encoder prompt based on num_tiles # Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert hf_config.vision_config.image_size % 14 == 0, \ assert hf_config.vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14" "chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk num_tokens = num_tiles * token_per_chunk
llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens
] * num_tokens
return llm_inputs return inputs
def get_max_mllama_image_tokens(ctx: InputContext) -> int: def get_max_mllama_image_tokens(ctx: InputContext) -> int:
...@@ -131,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int): ...@@ -131,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images) # <|image|> * num_images + 0 * (seq_len - num_images)
assert seq_len >= num_images, \ assert seq_len >= num_images, \
"seq_len should be greater than or equal to num_images" "seq_len should be greater than or equal to num_images"
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[MLLAMA_IMAGE_TOKEN_ID]) * num_images return SequenceData.from_prompt_token_counts(
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) (MLLAMA_IMAGE_TOKEN_ID, num_images),
return SequenceData(token_ids) (0, seq_len - num_images),
)
def dummy_encoder_seq_data(ctx: InputContext, num_images: int): def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
num_tokens = get_max_mllama_image_tokens(ctx) * num_images num_tokens = get_max_mllama_image_tokens(ctx) * num_images
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[MLLAMA_IMAGE_TOKEN_ID]) * num_tokens return SequenceData.from_prompt_token_counts(
return SequenceData(token_ids) (MLLAMA_IMAGE_TOKEN_ID, num_tokens))
def dummy_image(num_images: int, ): def dummy_image(num_images: int, ):
...@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor], cross_attention_states: Optional[torch.Tensor],
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
...@@ -697,15 +721,71 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -697,15 +721,71 @@ class MllamaTextCrossAttention(nn.Module):
q = q.view(-1, self.num_local_heads, self.head_dim) q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q) q = self.q_norm(q)
output = self.attn(q, if attention_mask is not None:
k, output = self.attention_with_mask(q, k, v, kv_cache,
v, attention_mask,
kv_cache, kv_range_for_decode,
attn_metadata, attn_metadata)
attn_type=AttentionType.ENCODER_DECODER) else:
output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
out, _ = self.o_proj(output) out, _ = self.o_proj(output)
return out return out
def attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
# can be optimized by xformers.BlockDiagonalMask.
# The mask is specially calculated for supporting multi
# images and interleaved images.
q_len = q.shape[0]
kv_len = k.shape[0]
q = q.transpose(0, 1).view(self.num_local_key_value_heads,
self.num_key_value_groups, q_len,
self.head_dim)
k = k.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len, self.head_dim)
v = v.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len, self.head_dim)
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
output = F.scaled_dot_product_attention(q,
k,
v,
attn_mask=attention_mask,
is_causal=False)
output = output.permute(2, 0, 1, 3).reshape(
q_len, self.num_local_heads * self.head_dim)
return output
class MllamaCrossAttentionDecoderLayer(torch.nn.Module): class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention """Cross-attention transformer block with tanh-gated attention
...@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor, cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor, cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor,
kv_cache: List[torch.Tensor], kv_cache: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
...@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states = self.cross_attn( hidden_states = self.cross_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
attention_mask=cross_attention_mask, attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
kv_cache=kv_cache, kv_cache=kv_cache,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module): ...@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
...@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module): ...@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask= full_text_row_masked_out_mask=
full_text_row_masked_out_mask, full_text_row_masked_out_mask,
kv_cache=kv_caches[idx], kv_cache=kv_caches[idx],
...@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module): ...@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
...@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module): ...@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module):
positions=positions, positions=positions,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -1026,36 +1112,102 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1026,36 +1112,102 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def flat_encoder_result(self, cross_attention_states: torch.Tensor, def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata): attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]):
cross_attention_states_flat = torch.zeros( cross_attention_states_flat = torch.zeros(
sum(attn_metadata.encoder_seq_lens), sum(actual_encoder_seq_lens),
cross_attention_states.shape[-1], cross_attention_states.shape[-1],
device=cross_attention_states.device, device=cross_attention_states.device,
dtype=cross_attention_states.dtype) dtype=cross_attention_states.dtype)
start_pos = 0 start_pos = 0
for seq_len, vision_token_in_batch in zip( for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
attn_metadata.encoder_seq_lens, cross_attention_states): cross_attention_states):
end_pos = start_pos + seq_len end_pos = start_pos + seq_len
cross_attention_states_flat[ cross_attention_states_flat[
start_pos:end_pos] = vision_token_in_batch[:seq_len] start_pos:end_pos] = vision_token_in_batch[:seq_len]
start_pos = end_pos start_pos = end_pos
cross_attention_states = cross_attention_states_flat cross_attention_states = cross_attention_states_flat
return cross_attention_states
def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int],
) -> Tuple[torch.Tensor]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
aspect_ratio_mask = image_inputs['aspect_ratio_mask']
cross_attention_states = self.vision_model(pixel_values,
aspect_ratio_ids,
aspect_ratio_mask)
cross_attention_states = self.multi_modal_projector(
cross_attention_states)
bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
cross_attention_states = cross_attention_states.view(
bsz, -1, image_token_dim)
cross_attention_states = self.flat_encoder_result(
cross_attention_states, attn_metadata, actual_encoder_seq_lens)
return cross_attention_states
def get_cross_attention_mask(
self,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
num_tiles: List[List[int]],
num_tokens_per_tile: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
token_ids = input_ids.tolist()
start = 0
batch_token_ids = []
for seq_len in attn_metadata.seq_lens:
batch_token_ids.append(token_ids[start:start + seq_len])
start += seq_len
sparse_mask = [
get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
for t in batch_token_ids
]
# Skip generating cross-attention mask if all samples
# are text-only or have only 1 leading image.
if skip_attention_mask(sparse_mask):
return None, None
dense_mask, tile_range_for_decode = \
convert_sparse_cross_attention_mask_to_dense(
sparse_mask, num_tiles, attn_metadata.seq_lens)
cross_attention_mask = \
convert_dense_cross_attention_mask_to_tensor(
dense_mask, num_tokens_per_tile, input_ids.device, dtype)
kv_range_for_decode = [[
t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
] for t in tile_range_for_decode]
return cross_attention_mask, kv_range_for_decode
def get_full_text_row_masked_out_mask(
self,
attn_metadata: AttentionMetadata,
device: torch.device,
) -> torch.Tensor:
full_text_row_masked_out_mask = torch.ones( full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0 start_pos = 0
for seq_len, encoder_seq_len in zip( for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
attn_metadata.seq_lens_tensor.cpu(), attn_metadata.encoder_seq_lens):
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0: if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos + full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False seq_len] = False
start_pos += seq_len start_pos += seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
cross_attention_states.device) device)
return full_text_row_masked_out_mask
return cross_attention_states, full_text_row_masked_out_mask
def forward( def forward(
self, self,
...@@ -1069,39 +1221,54 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1069,39 +1221,54 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata.num_decode_tokens > 0: attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported") raise ValueError("Chunk prefill not supported")
image_inputs = self._parse_and_validate_image_input(**kwargs) image_inputs = self._parse_and_validate_image_input(**kwargs)
cross_attention_states = None
cross_attention_mask = None
kv_range_for_decode = None
# For 1) text-only prefill and decode, 2) image-present decode.
if image_inputs is None: if image_inputs is None:
cross_attention_mask = None
full_text_row_masked_out_mask = ( full_text_row_masked_out_mask = (
attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to( attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
input_ids.device) input_ids.device)
cross_attention_states = None
skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
# For image-present prefill.
else: else:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
aspect_ratio_mask = image_inputs['aspect_ratio_mask']
cross_attention_states = self.vision_model(pixel_values,
aspect_ratio_ids,
aspect_ratio_mask)
cross_attention_states = self.multi_modal_projector(
cross_attention_states)
bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
cross_attention_states = cross_attention_states.view(
bsz, -1, image_token_dim)
cross_attention_states, full_text_row_masked_out_mask = \
self.flat_encoder_result(cross_attention_states, attn_metadata)
skip_cross_attention = False skip_cross_attention = False
# TODO: support multi-image by this mask
cross_attention_mask = None # Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See input_processor_for_mllama() for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t[0].tolist() for t in num_tiles_tensor]
num_tokens_per_tile = (self.image_size // 14)**2 + 1
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
for actual_len, last_group_len in zip(
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
assert actual_len >= last_group_len
cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)
full_text_row_masked_out_mask = \
self.get_full_text_row_masked_out_mask(
attn_metadata, input_ids.device)
cross_attention_mask, kv_range_for_decode = \
self.get_cross_attention_mask(
input_ids, attn_metadata, num_tiles,
num_tokens_per_tile, cross_attention_states.dtype)
outputs = self.language_model( outputs = self.language_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
for mask in sparse_mask:
# Skip text-only samples.
if len(mask) == 0:
continue
# If the sample contains more than 1 images,
# we can't skip mask.
if len(mask) != 1:
return False
# If the sample contains only 1 image,
# but the image is not the leading one,
# we can't skip mask.
if mask[0][0] != 0 or mask[0][1] != -1:
return False
return True
def convert_sparse_cross_attention_mask_to_dense(
sparse_mask: List[List[List[int]]],
num_tiles: List[List[int]],
lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
# A list of ranges, range[i] = [start, end] means
# if the i-th sample has N tiles in total, the tiles[start, end]
# will be used for cross-attention decoding.
tile_range_for_decode = []
seq_start = 0
tile_start = 0
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
ts, td = -1, 0
for mask, tile in zip(masks, tiles):
if len(mask) != 2:
continue
start, end = mask
end = min(end, length)
if end == -1:
end = length
if end == length:
if ts == -1:
ts = tile_start
td += tile
dense_mask[seq_start + start:seq_start + end,
tile_start:tile_start + tile] = 1
tile_start += tile
tile_range_for_decode.append((ts, ts + td))
seq_start += length
return dense_mask, tile_range_for_decode
def convert_dense_cross_attention_mask_to_tensor(
cross_attention_token_mask: np.ndarray,
num_tokens_per_tile: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
mask = 1.0 - mask
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
ninf = torch.finfo(dtype).min
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return mask
# Adapted from
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from dataclasses import dataclass, field
from typing import List, Union
@dataclass
class ModelKeys:
model_type: str = None
module_list: str = None
embedding: str = None
mlp: str = None
down_proj: str = None
attention: str = None
o_proj: str = None
q_proj: str = None
k_proj: str = None
v_proj: str = None
qkv_proj: str = None
qk_proj: str = None
qa_proj: str = None
qb_proj: str = None
kva_proj: str = None
kvb_proj: str = None
output: str = None
@dataclass
class MultiModelKeys(ModelKeys):
language_model: List[str] = field(default_factory=list)
connector: List[str] = field(default_factory=list)
# vision tower and audio tower
tower_model: List[str] = field(default_factory=list)
generator: List[str] = field(default_factory=list)
@staticmethod
def from_string_field(language_model: Union[str, List[str]] = None,
connector: Union[str, List[str]] = None,
tower_model: Union[str, List[str]] = None,
generator: Union[str, List[str]] = None,
**kwargs) -> 'MultiModelKeys':
def to_list(value):
if value is None:
return []
return [value] if isinstance(value, str) else list(value)
return MultiModelKeys(language_model=to_list(language_model),
connector=to_list(connector),
tower_model=to_list(tower_model),
generator=to_list(generator),
**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment