Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d2051cc
Commit
6d2051cc
authored
Oct 21, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev
parents
2c7f740a
a2c71c54
Changes
457
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1903 additions
and
1142 deletions
+1903
-1142
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+10
-14
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+107
-345
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+262
-194
vllm/model_executor/models/llama_embedding.py
vllm/model_executor/models/llama_embedding.py
+0
-87
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+44
-42
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+44
-51
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+51
-48
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+63
-66
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+409
-0
vllm/model_executor/models/mamba_cache.py
vllm/model_executor/models/mamba_cache.py
+158
-0
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+73
-43
vllm/model_executor/models/minicpm3.py
vllm/model_executor/models/minicpm3.py
+39
-9
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+213
-134
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+15
-26
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+42
-19
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+304
-64
vllm/model_executor/models/module_mapping.py
vllm/model_executor/models/module_mapping.py
+69
-0
No files found.
Too many changes to show.
To preserve performance only
457 of 457+
files are displayed.
Plain diff
Email patch
vllm/model_executor/models/jais.py
View file @
6d2051cc
...
...
@@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
JAISConfig
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
SwiGLUActivation
(
nn
.
Module
):
...
...
@@ -244,6 +245,9 @@ class JAISModel(nn.Module):
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
forward
(
self
,
...
...
@@ -279,7 +283,7 @@ class JAISModel(nn.Module):
return
hidden_states
class
JAISLMHeadModel
(
nn
.
Module
):
class
JAISLMHeadModel
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
vocab_size
=
config
.
vocab_size
,
scale
=
self
.
output_logits_scale
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/jamba.py
View file @
6d2051cc
# coding=utf-8
"""Inference-only Jamba model."""
from
dataclasses
import
dataclass
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
transformers
import
JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -25,31 +22,25 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
HasInnerState
from
vllm.model_executor.model_loader.weight_utils
import
(
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
HasInnerState
,
SupportsLoRA
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
@
dataclass
class
MambaCacheParams
:
is_prompt
:
bool
=
False
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class
JambaMambaMixer
(
nn
.
Module
):
"""
...
...
@@ -62,10 +53,9 @@ class JambaMambaMixer(nn.Module):
**selective** state spaces)
"""
def
__init__
(
self
,
config
:
JambaConfig
,
layer_idx
):
def
__init__
(
self
,
config
:
JambaConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
config
.
hidden_size
self
.
ssm_state_size
=
config
.
mamba_d_state
self
.
conv_kernel_size
=
config
.
mamba_d_conv
...
...
@@ -101,16 +91,6 @@ class JambaMambaMixer(nn.Module):
bias
=
True
,
skip_bias_add
=
True
)
def
weight_loader
(
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
param
.
data
.
copy_
(
loaded_weight
.
data
.
split
(
loaded_weight
.
shape
[
0
]
//
tp_size
,
dim
=
0
)[
tp_rank
])
def
A_weight_loader
(
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
weight_loader
(
param
,
-
torch
.
exp
(
loaded_weight
.
float
()))
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
A
=
nn
.
Parameter
(
torch
.
empty
(
...
...
@@ -120,8 +100,10 @@ class JambaMambaMixer(nn.Module):
))
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
self
.
intermediate_size
//
tp_size
))
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
weight_loader
})
set_weight_attrs
(
self
.
A
,
{
"weight_loader"
:
A_weight_loader
})
set_weight_attrs
(
self
.
D
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
a_weight_loader
=
composed_weight_loader
(
sharded_weight_loader
(
0
),
lambda
x
:
-
torch
.
exp
(
x
.
float
()))
set_weight_attrs
(
self
.
A
,
{
"weight_loader"
:
a_weight_loader
})
self
.
out_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
...
...
@@ -138,42 +120,48 @@ class JambaMambaMixer(nn.Module):
self
.
c_layernorm
=
RMSNorm
(
self
.
ssm_state_size
,
eps
=
config
.
rms_norm_eps
)
def
mamba_forward
(
self
,
hidden_states
:
torch
.
Tensor
,
cache_params
:
MambaCacheParams
=
None
):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
1
,
2
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=
1
)
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
hidden_states
,
gate
=
projected_states
.
chunk
(
2
,
dim
=
-
2
)
# 2. Convolution sequence transformation
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
cache_params
is
not
None
and
not
cache_params
.
is_prompt
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
squeeze
(
-
1
),
cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
)
hidden_states
=
hidden_states
.
unsqueeze
(
-
1
)
else
:
if
cache_params
is
not
None
:
conv_states
=
nn
.
functional
.
pad
(
hidden_states
,
(
self
.
conv_kernel_size
-
hidden_states
.
shape
[
-
1
],
0
))
cache_params
.
conv_state
.
copy_
(
conv_states
)
hidden_states
,
_
=
causal_conv1d_fn
(
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states
=
causal_conv1d_fn
(
hidden_states
,
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
)
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters
=
self
.
x_proj
(
hidden_states
.
transpose
(
1
,
2
))[
0
]
ssm_parameters
=
self
.
x_proj
(
hidden_states
.
transpose
(
-
2
,
-
1
))[
0
]
time_step
,
B
,
C
=
torch
.
split
(
ssm_parameters
,
...
...
@@ -184,72 +172,47 @@ class JambaMambaMixer(nn.Module):
B
=
self
.
b_layernorm
(
B
.
contiguous
())
C
=
self
.
c_layernorm
(
C
.
contiguous
())
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
1
,
2
)
discrete_time_step
=
self
.
dt_proj
(
time_step
)[
0
].
transpose
(
-
2
,
-
1
)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias
=
(
self
.
dt_proj
.
bias
.
float
()
if
hasattr
(
self
.
dt_proj
,
"bias"
)
else
None
)
if
cache_params
is
not
None
and
not
cache_params
.
is_prompt
:
scan_outputs
=
selective_state_update
(
cache_params
.
ssm_state
,
hidden_states
[...,
0
],
discrete_time_step
[...,
0
],
self
.
A
,
B
[:,
0
],
C
[:,
0
],
self
.
D
,
gate
[...,
0
],
time_proj_bias
,
dt_softplus
=
True
,
).
unsqueeze
(
-
1
)
else
:
scan_outputs
,
ssm_state
=
selective_scan_fn
(
if
attn_metadata
.
query_start_loc
is
not
None
\
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
1
,
2
),
C
.
transpose
(
1
,
2
),
B
.
transpose
(
-
2
,
-
1
),
C
.
transpose
(
-
2
,
-
1
),
self
.
D
.
float
(),
gate
,
time_proj_bias
,
delta_softplus
=
True
,
return_last_state
=
True
,
)
if
ssm_state
is
not
None
and
cache_params
is
not
None
:
cache_params
.
ssm_state
.
copy_
(
ssm_state
)
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
selective_state_update
(
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
B
,
C
,
self
.
D
,
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
contextualized_states
=
self
.
out_proj
(
scan_outputs
.
transpose
(
1
,
2
))[
0
]
contextualized_states
=
self
.
out_proj
(
scan_outputs
.
transpose
(
-
2
,
-
1
))[
0
]
return
contextualized_states
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
):
if
attn_metadata
.
prefill_metadata
is
not
None
:
offset
=
0
for
i
,
prompt_len
in
enumerate
(
attn_metadata
.
prefill_metadata
.
seq_lens
):
cache
=
MambaCacheParams
(
True
,
conv_state
=
conv_state
[
i
].
unsqueeze
(
0
),
ssm_state
=
ssm_state
[
i
].
unsqueeze
(
0
))
hidden_states
[
offset
:
offset
+
prompt_len
].
copy_
(
self
.
mamba_forward
(
hidden_states
[
offset
:
offset
+
prompt_len
].
unsqueeze
(
0
),
cache_params
=
cache
)[
0
])
offset
+=
prompt_len
else
:
cache
=
MambaCacheParams
(
False
,
conv_state
=
conv_state
,
ssm_state
=
ssm_state
)
hidden_states
=
self
.
mamba_forward
(
hidden_states
.
unsqueeze
(
1
),
cache_params
=
cache
)
hidden_states
=
hidden_states
.
squeeze
(
1
)
return
hidden_states
class
JambaMoE
(
nn
.
Module
):
...
...
@@ -323,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
mamba
=
JambaMambaMixer
(
config
,
layer_idx
)
self
.
mamba
=
JambaMambaMixer
(
config
)
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
...
...
@@ -338,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -349,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
conv_state
,
ssm_state
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
mamba_cache_params
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
...
...
@@ -507,17 +469,14 @@ class JambaModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
kv_cache
=
None
current_ssm_state
=
None
current_conv_state
=
None
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
]
...
...
@@ -525,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer
=
i
-
(
1
+
(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
)
current_ssm_state
=
ssm_state
[
current_st
at
e
_layer
]
current_conv_state
=
conv_state
[
current_state_layer
]
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer
_idx
(
current_state_layer
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
...
...
@@ -534,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
conv_state
=
current_conv_state
,
ssm_state
=
current_ssm_state
,
)
mamba_cache_params
=
layer_mamba_cache_params
)
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -571,8 +528,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
,
)
->
None
:
assert
not
scheduler_config
.
chunked_prefill_enabled
,
\
"Jamba currently does not support chunked prefill"
assert
not
cache_config
.
enable_prefix_caching
,
\
"Jamba currently does not support prefix caching"
...
...
@@ -596,10 +551,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
self
.
mamba_cache
:
Optional
[
MambaCacheManager
]
=
None
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -611,242 +564,51 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
):
if
not
self
.
mamba_cache
:
self
.
_prepare_mamba_cache
()
if
"seqlen_agnostic_capture_inputs"
not
in
kwargs
:
# We get here only on Prefill/Eager mode runs
assert
all
(
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
mamba_cache
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
else
:
# CUDA graph capturing runs
mamba_cache
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
layers_type
=
self
.
config
.
layers_block_type
num_mamba_layers
=
sum
(
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
(
mamba_cache_tensors
,
state_indices_tensor
,
)
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
mamba_cache_params
=
MambaCacheParams
(
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
],
state_indices_tensor
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache
[
0
],
mamba_cache
[
1
])
attn_metadata
,
mamba_cache_params
)
return
hidden_states
def
_swap_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
[
to_index
,
from_index
]]
=
\
cache_t
[:,
[
from_index
,
to_index
]]
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
to_index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
all_occupied_indices
:
List
[
int
]):
if
index
in
all_occupied_indices
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
# In case occupied, move the occupied to a new empty block
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
to_index
=
first_free_index
)
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
destination_index
:
int
):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
}
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
# case of decoding n>1, copy prefill cache to decoding indices
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
destination_index
else
:
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]):
running_indices
=
[]
request_ids_to_seq_ids_flatten
=
[
(
req_id
,
seq_id
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
# Do not allocate cache index for requests that run
# and finish right after
continue
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seq_id
,
dest_index
)
running_indices
.
append
(
dest_index
)
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
return
(
conv_state
,
temporal_state
)
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
range
(
batch_size
)
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
assert
all
(
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
,
finished_requests_ids
)
return
self
.
mamba_cache
.
copy_inputs_before_cuda_graphs
(
input_buffers
,
**
kwargs
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
def
_release_mamba_cache
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
indices_range
=
list
(
range
(
max_possible_batch_size
))
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
i
in
indices_range
:
if
i
not
in
all_occupied_indices
:
return
i
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
"should never happen"
)
return
self
.
mamba_cache
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
def
_get_mamba_cache_shape
(
self
)
->
Tuple
[
Optional
[
Tuple
[
int
,
int
]],
Optional
[
Tuple
[
int
,
int
]]]:
self
)
->
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
]]:
world_size
=
get_tensor_model_parallel_world_size
()
hidden_size
=
self
.
config
.
hidden_size
conv_state_shape
=
(
self
.
config
.
mamba_expand
*
hidden_size
//
world_size
,
self
.
config
.
mamba_d_conv
,
self
.
config
.
mamba_d_conv
-
1
,
)
temporal_state_shape
=
(
self
.
config
.
mamba_expand
*
self
.
config
.
hidden_size
//
world_size
,
self
.
config
.
mamba_expand
*
hidden_size
//
world_size
,
self
.
config
.
mamba_d_state
,
)
return
conv_state_shape
,
temporal_state_shape
def
_prepare_mamba_cache
(
self
):
dtype
=
self
.
lm_head
.
weight
.
dtype
layers_type
=
self
.
config
.
layers_block_type
mamba_layers
=
sum
(
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
self
.
mamba_cache
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
),
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
temporal_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
))
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/llama.py
View file @
6d2051cc
...
...
@@ -30,6 +30,7 @@ import os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -39,8 +40,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.
quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.
pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
...
...
@@ -49,12 +50,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
...
@@ -77,12 +80,15 @@ class LlamaMLP(nn.Module):
output_sizes
=
[
intermediate_size
]
*
2
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -166,12 +172,15 @@ class LlamaAttention(nn.Module):
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
...
...
@@ -260,12 +269,10 @@ class LlamaDecoderLayer(nn.Module):
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
...
...
@@ -274,6 +281,13 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
...
...
@@ -307,12 +321,27 @@ class LlamaModel(nn.Module):
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
,
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of applying ``self.embed_tokens`` to ``input_ids``."""
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
...
...
@@ -338,13 +367,9 @@ class LlamaModel(nn.Module):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -355,153 +380,6 @@ class LlamaModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
LlamaForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
,
"lm_head"
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping
=
{
"layers"
:
"model.layers"
,
"attention"
:
"self_attn"
,
"wq"
:
"q_proj"
,
"wk"
:
"k_proj"
,
"wv"
:
"v_proj"
,
"wo"
:
"o_proj"
,
"attention_norm"
:
"input_layernorm"
,
"feed_forward"
:
"mlp"
,
"w1"
:
"gate_proj"
,
"w2"
:
"down_proj"
,
"w3"
:
"up_proj"
,
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
}
def
__init__
(
self
,
config
:
LlamaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
...
...
@@ -513,8 +391,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
name
,
loaded_weight
=
self
.
maybe_remap_mistral
(
name
,
loaded_weight
)
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
...
...
@@ -522,11 +398,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
if
scale_name
:
=
get_compressed_tensors_cache_scale
(
name
):
# Loading kv cache scales for compressed-tensors quantization
param
=
params_dict
[
scale_name
]
...
...
@@ -535,7 +406,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
loaded_weight
=
loaded_weight
[
0
]
weight_loader
(
param
,
loaded_weight
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
...
...
@@ -566,7 +437,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
...
...
@@ -574,7 +445,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
"lm_head.weight"
#
"lm_head.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
...
...
@@ -656,7 +527,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
k
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
...
...
@@ -667,8 +538,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quantization_param_path
,
tp_rank
,
tp_size
,
self
.
config
.
num_hidden_layers
,
self
.
config
.
__class__
.
model_type
):
if
not
isinstance
(
self
.
model
.
layers
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
model
.
layers
[
layer_idx
].
self_attn
if
not
isinstance
(
self
.
layers
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
layers
[
layer_idx
].
self_attn
if
is_hip
():
# The scaling factor convention we are assuming is
...
...
@@ -682,13 +553,161 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
raise
RuntimeError
(
"Self attention has no KV cache scaling "
"factor attribute!"
)
class
LlamaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
,
"lm_head"
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
}
embedding_padding_modules
=
[
"lm_head"
]
# BitsAndBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping
=
{
"layers"
:
"model.layers"
,
"attention"
:
"self_attn"
,
"wq"
:
"q_proj"
,
"wk"
:
"k_proj"
,
"wv"
:
"v_proj"
,
"wo"
:
"o_proj"
,
"attention_norm"
:
"input_layernorm"
,
"feed_forward"
:
"mlp"
,
"w1"
:
"gate_proj"
,
"w2"
:
"down_proj"
,
"w3"
:
"up_proj"
,
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
}
def __init__(
    self,
    config: LlamaConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    lora_config: Optional[LoRAConfig] = None,
) -> None:
    """Build the Llama backbone and, on the last pipeline rank, the LM head.

    Args:
        config: Llama configuration; ``vocab_size``, ``hidden_size`` and
            ``tie_word_embeddings`` are read here, and ``logit_scale`` is
            read with a default of ``1.0``.
        cache_config: Forwarded to ``LlamaModel``.
        quant_config: Forwarded to ``LlamaModel`` and ``ParallelLMHead``.
        lora_config: Forwarded to ``LlamaModel``. When truthy it also
            extends the vocabulary by ``lora_extra_vocab_size`` and
            selects ``lora_vocab_padding_size`` as the head padding.
    """
    super().__init__()

    self.config = config
    self.lora_config = lora_config

    self.model = LlamaModel(
        config,
        cache_config,
        quant_config,
        lora_config=lora_config,
        prefix="model",
    )

    if not get_pp_group().is_last_rank:
        # Ranks other than the last one only get a placeholder head.
        self.lm_head = PPMissingLayer()
    else:
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # We need bigger padding if using lora for kernel compatibility.
        if lora_config:
            vocab_padding = lora_config.lora_vocab_padding_size
        else:
            vocab_padding = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            logit_scale,
        )
        self.sampler = Sampler()

    # Expose the backbone's factory for empty intermediate tensors.
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the wrapped ``self.model`` and return its output unchanged.

    All five arguments are passed positionally, in order, to
    ``self.model``.
    """
    return self.model(
        input_ids,
        positions,
        kv_caches,
        attn_metadata,
        intermediate_tensors,
    )
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Return ``self.logits_processor`` applied to the LM head and inputs.

    The processor receives ``self.lm_head``, ``hidden_states`` and
    ``sampling_metadata``, in that order.
    """
    return self.logits_processor(
        self.lm_head,
        hidden_states,
        sampling_metadata,
    )
def sample(self, logits: torch.Tensor,
           sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
    """Return the result of ``self.sampler(logits, sampling_metadata)``."""
    return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights through an ``AutoWeightsLoader``.

    Each ``(name, tensor)`` pair is passed through
    ``self.maybe_remap_mistral`` lazily, as the loader consumes it.
    When ``config.tie_word_embeddings`` is set, entries under the
    ``"lm_head."`` prefix are skipped by the loader.
    """
    if self.config.tie_word_embeddings:
        skipped = ["lm_head."]
    else:
        skipped = None
    loader = AutoWeightsLoader(self, skip_prefixes=skipped)

    # Obtain the iterator up front; the remapping itself stays lazy.
    pending = iter(weights)

    def remapped():
        for name, loaded_weight in pending:
            yield self.maybe_remap_mistral(name, loaded_weight)

    loader.load_weights(remapped())
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Forward ``quantization_param_path`` to the wrapped model's loader."""
    backbone = self.model
    backbone.load_kv_cache_scales(quantization_param_path)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def
maybe_remap_mistral
(
self
,
name
:
str
,
loaded_weight
:
torch
.
Tensor
)
->
Tuple
[
str
,
torch
.
Tensor
]:
self
,
name
:
str
,
loaded_weight
:
torch
.
Tensor
,
)
->
Tuple
[
str
,
torch
.
Tensor
]:
def
permute
(
w
,
n_heads
):
def
permute
(
w
:
torch
.
Tensor
,
n_heads
:
int
):
attn_in
=
self
.
config
.
head_dim
*
n_heads
attn_out
=
self
.
config
.
hidden_size
...
...
@@ -711,3 +730,52 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
name
=
name
.
replace
(
item
,
mapping
[
item
])
return
name
,
loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsPP):
    """Llama backbone paired with a pooler, for embedding use.

    The class wraps a ``LlamaModel`` and exposes forward, pooling and
    weight-loading entry points that delegate to it.

    Attributes:
        model: The wrapped ``LlamaModel``; ``forward`` calls it.
        _pooler: A ``Pooler`` created with ``PoolingType.LAST`` and
            ``normalize=True``; ``pooler`` calls it.
    """

    def __init__(self, **kwargs) -> None:
        """Create the backbone from ``kwargs`` and attach the pooler."""
        super().__init__()
        backbone = LlamaModel(**kwargs)
        self.model = backbone
        self._pooler = Pooler(
            pooling_type=PoolingType.LAST,
            normalize=True,
        )
        # Expose the backbone's factory for empty intermediate tensors.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Call ``self.model`` with all arguments passed positionally."""
        backbone_output = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        return backbone_output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Return ``self._pooler(hidden_states, pooling_metadata)``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Hand ``weights`` to the wrapped model's ``load_weights``."""
        backbone = self.model
        backbone.load_weights(weights)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Forward the path to the wrapped model's ``load_kv_cache_scales``."""
        backbone = self.model
        backbone.load_kv_cache_scales(quantization_param_path)
vllm/model_executor/models/llama_embedding.py
deleted
100644 → 0
View file @
2c7f740a
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
PoolerOutput
class
LlamaEmbeddingModel
(
nn
.
Module
):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
**
kwargs
,
)
->
None
:
super
().
__init__
()
self
.
model
=
LlamaModel
(
**
kwargs
)
self
.
_pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
.
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
model
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/llava.py
View file @
6d2051cc
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -8,11 +9,10 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -21,12 +21,12 @@ from vllm.utils import is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
input_processor_for_siglip
)
from
.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
class
LlavaImagePixelInputs
(
TypedDict
):
...
...
@@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
raise
NotImplementedError
(
msg
)
def
input_processor_for_llava
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
def
input_processor_for_llava
(
ctx
:
InputContext
,
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
...
...
@@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return
input_processor_for_clip
(
model_config
,
vision_config
,
llm_
inputs
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
...
...
@@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
llm_
inputs
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
...
...
@@ -198,7 +198,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava
)
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
config
:
LlavaConfig
,
...
...
@@ -220,6 +220,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@cached_property
def sampler(self):
    """Return the language model's ``sampler``, or a new ``Sampler()``.

    The language model's attribute is used whenever it exists, whatever
    its value; a new ``Sampler`` is created only when it is absent.
    """
    absent = object()
    shared = getattr(self.language_model, "sampler", absent)
    return Sampler() if shared is absent else shared
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
...
...
@@ -315,7 +325,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
"""Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the
...
...
@@ -351,26 +361,32 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
See also:
:class:`LlavaImageInputs`
"""
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
if
intermediate_tensors
is
not
None
:
input_ids
=
None
else
:
inputs_embeds
=
None
else
:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
else
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
input_ids
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
@@ -391,19 +407,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"multi_modal_projector"
]:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
vllm/model_executor/models/llava_next.py
View file @
6d2051cc
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -11,10 +12,9 @@ from typing_extensions import NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -23,13 +23,13 @@ from vllm.utils import is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
get_clip_patch_grid_length
,
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.llava
import
LlavaMultiModalProjector
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
448
...
...
@@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
raise
NotImplementedError
(
msg
)
def
input_processor_for_llava_next
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
def
input_processor_for_llava_next
(
ctx
:
InputContext
,
inputs
:
DecoderOnlyInputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
...
...
@@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return
input_processor_for_clip
(
model_config
,
vision_config
,
llm_
inputs
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
...
...
@@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
llm_
inputs
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
...
...
@@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_next_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava_next
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava_next
)
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
config
:
LlavaNextConfig
,
...
...
@@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
_init_vision_tower
(
config
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
...
...
@@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@cached_property
def sampler(self):
    """Return the language model's ``sampler`` if present, else ``Sampler()``.

    Presence is decided by attribute lookup on ``self.language_model``;
    an existing attribute is returned as-is, even when it is falsy.
    """
    language_model = self.language_model
    not_found = object()
    existing = getattr(language_model, "sampler", not_found)
    if existing is not_found:
        return Sampler()
    return existing
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expected_dims
=
(
2
,
)
...
...
@@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
"""Run forward pass for LlaVA-NeXT.
One key thing to understand is the `input_ids` already accounts for the
...
...
@@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
See also:
:class:`LlavaNextImageInputs`
"""
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
intermediate_tensors
is
not
None
:
input_ids
=
None
inputs_embeds
=
None
else
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
input_ids
=
None
else
:
inputs_embeds
=
None
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
@@ -627,27 +642,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"multi_modal_projector"
]:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load newline
for
name
,
loaded_weight
in
weights_group
[
"image_newline"
]:
assert
name
==
""
param
=
self
.
image_newline
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
vllm/model_executor/models/llava_next_video.py
View file @
6d2051cc
import
math
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -10,12 +11,11 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -25,10 +25,10 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
)
from
.utils
import
(
group_weights_with_prefix
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
# For profile run
...
...
@@ -140,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
def
input_processor_for_llava_next_video
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"video"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
video_data
=
multi_modal_data
[
"video"
]
model_config
=
ctx
.
model_config
...
...
@@ -161,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext,
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
tokenizer
,
llm_
inputs
.
get
(
"prompt"
),
llm_
inputs
[
"prompt_token_ids"
],
inputs
.
get
(
"prompt"
),
inputs
[
"prompt_token_ids"
],
placeholder_token_id
=
hf_config
.
video_token_index
,
repeat_count
=
video_feature_size
,
)
return
LLMI
nputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
return
token_i
nputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
elif
is_list_of
(
video_data
,
np
.
ndarray
):
raise
NotImplementedError
(
...
...
@@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module):
"video"
,
get_max_llava_next_video_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava_next_video
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava_next_video
)
class
LlavaNextVideoForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
LlavaNextVideoForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
config
:
LlavaNextVideoConfig
,
...
...
@@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
_init_vision_tower
(
config
)
self
.
vision_resampler
=
LlavaNextVideoPooler
(
config
)
self
.
multi_modal_projector
=
LlavaNextMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
vision_resampler
=
LlavaNextVideoPooler
(
config
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
Sampler
()
def
_validate_video_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
...
...
@@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
"""Run forward pass for LlaVA-NeXT-Video.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
# merge video embeddings into input embeddings
if
video_input
is
not
None
:
video_embeddings
=
self
.
_process_video_pixels
(
video_input
)
inputs_embeds
=
self
.
language_model
\
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
video_embeddings
,
self
.
config
.
video_token_index
)
if
intermediate_tensors
is
not
None
:
input_ids
=
None
else
:
inputs_embeds
=
None
else
:
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
video_input
is
not
None
:
video_embeddings
=
self
.
_process_video_pixels
(
video_input
)
inputs_embeds
=
self
.
language_model
\
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
video_embeddings
,
self
.
config
.
video_token_index
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
@@ -445,19 +458,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"multi_modal_projector"
]:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
loader
=
AutoWeightsLoader
(
self
,
# This model doesn't support images for now
ignore_unexpected_prefixes
=
[
"image_newline"
],
)
loader
.
load_weights
(
weights
)
vllm/model_executor/models/llava_onevision.py
View file @
6d2051cc
import
math
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -14,13 +15,11 @@ from typing_extensions import NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
...
...
@@ -31,14 +30,12 @@ from vllm.utils import is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_seq_data_for_clip
,
dummy_video_for_clip
,
get_clip_image_feature_size
,
get_clip_patch_grid_length
,
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.siglip
import
(
SiglipVisionModel
,
dummy_seq_data_for_siglip
,
dummy_video_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
448
...
...
@@ -253,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
def
input_processor_when_multimodal_input_image
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
LlavaOnevisionConfig
)
...
...
@@ -291,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
return
input_processor_for_clip
(
model_config
,
vision_config
,
llm_
inputs
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
...
...
@@ -299,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
return
input_processor_for_siglip
(
model_config
,
vision_config
,
llm_
inputs
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
...
...
@@ -309,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
def
input_processor_when_multimodal_input_video
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"video"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
video_data
=
multi_modal_data
[
"video"
]
model_config
=
ctx
.
model_config
...
...
@@ -327,15 +324,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
tokenizer
,
llm_
inputs
.
get
(
"prompt"
),
llm_
inputs
[
"prompt_token_ids"
],
inputs
.
get
(
"prompt"
),
inputs
[
"prompt_token_ids"
],
placeholder_token_id
=
hf_config
.
video_token_index
,
repeat_count
=
video_feature_size
,
)
return
LLMI
nputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
return
token_i
nputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
elif
is_list_of
(
video_data
,
np
.
ndarray
):
raise
NotImplementedError
(
...
...
@@ -346,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
def
input_processor_for_llava_onevision
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
(
"video"
not
in
multi_modal_data
and
"image"
not
in
multi_modal_data
):
return
llm_
inputs
return
inputs
if
"image"
in
multi_modal_data
:
return
input_processor_when_multimodal_input_image
(
ctx
,
llm_
inputs
)
return
input_processor_when_multimodal_input_image
(
ctx
,
inputs
)
if
"video"
in
multi_modal_data
:
return
input_processor_when_multimodal_input_video
(
ctx
,
llm_
inputs
)
return
input_processor_when_multimodal_input_video
(
ctx
,
inputs
)
msg
=
"Unsupported multi data type"
raise
NotImplementedError
(
msg
)
...
...
@@ -414,7 +411,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
"video"
,
get_max_llava_onevision_video_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava_onevision
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava_onevision
)
class
LlavaOnevisionForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
LlavaOnevisionForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
config
:
LlavaOnevisionConfig
,
...
...
@@ -434,6 +432,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
Sampler
()
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expected_dims
=
(
2
,
)
...
...
@@ -805,39 +813,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
"""Run forward pass for LlaVA-Onevision.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
# merge video embeddings into input embeddings
if
modalities
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
if
"images"
in
modalities
:
image_input
=
modalities
[
"images"
]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
if
"videos"
in
modalities
:
video_input
=
modalities
[
"videos"
]
video_embeddings
=
self
.
_process_video_pixels
(
video_input
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
video_embeddings
,
self
.
config
.
video_token_index
)
if
intermediate_tensors
is
not
None
:
input_ids
=
None
else
:
inputs_embeds
=
None
else
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
modalities
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
if
"images"
in
modalities
:
image_input
=
modalities
[
"images"
]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
if
"videos"
in
modalities
:
video_input
=
modalities
[
"videos"
]
video_embeddings
=
self
.
_process_video_pixels
(
video_input
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
video_embeddings
,
self
.
config
.
video_token_index
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
@@ -858,19 +869,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"multi_modal_projector"
]:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
vllm/model_executor/models/mamba.py
0 → 100644
View file @
6d2051cc
# coding=utf-8
"""PyTorch MAMBA model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
# Alias for a pair of tensors. `MambaForCausalLM.forward` annotates its
# `kv_caches` argument as a list of these.
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
    """Selective state-space (Mamba) mixer.

    Derives the state space parameters ∆, A, B, C and D and uses them to
    produce the `contextualized_states`. A and D do not depend on the
    input (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for
    why A isn't selective), whereas ∆, B and C are computed from the
    input. That input dependence is the key difference between Mamba and
    the linear time invariant S4, and is why Mamba is called a
    **selective** state space model.
    """

    def __init__(self, config: MambaConfig, layer_idx):
        """Build the projections and state-space parameters of one mixer.

        Args:
            config: Supplies `hidden_size`, `state_size`, `conv_kernel`,
                `intermediate_size`, `time_step_rank`, `use_conv_bias`,
                `use_bias` and `hidden_act`.
            layer_idx: Index of the owning layer; only stored on `self`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)

        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=config.use_conv_bias,
        )
        # Add a singleton axis so the linear weight has the conv1d weight
        # layout. This cannot be done through `weight_loader`: the
        # `ColumnParallelLinear` already defines one and `set_weight_attrs`
        # does not allow overriding it.
        conv_weight = self.conv1d.weight
        conv_weight.data = conv_weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=config.use_bias,
        )

        # Selective projection: makes dt, B and C depend on the input.
        x_proj_width = self.time_step_rank + self.ssm_state_size * 2
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            x_proj_width,
            bias=False,
        )

        # Time step projection (discretization). `skip_bias_add=True`
        # because forward applies dt_proj without the bias; the bias is
        # handed to the selective scan kernel instead.
        self.dt_proj = ColumnParallelLinear(
            self.time_step_rank,
            self.intermediate_size,
            bias=True,
            skip_bias_add=True,
        )

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})

        def _neg_exp(log_a):
            # Checkpoints carry `A_log`; the stored parameter is
            # -exp(A_log) in float32.
            return -torch.exp(log_a.float())

        a_weight_loader = composed_weight_loader(sharded_weight_loader(0),
                                                 _neg_exp)
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=config.use_bias,
            input_is_parallel=True,
        )
        self.activation = config.hidden_act

    def forward(self, hidden_states: torch.Tensor,
                attn_metadata: AttentionMetadata,
                mamba_cache_params: MambaCacheParams):
        """Run the mixer over `hidden_states`.

        The variable-length kernels (`causal_conv1d_fn`,
        `selective_scan_fn`) are used when `attn_metadata` carries both
        `query_start_loc` and `context_lens_tensor`; otherwise the
        single-step update kernels are used.

        Args:
            hidden_states: Input activations fed to `in_proj`.
            attn_metadata: Source of `query_start_loc` and
                `context_lens_tensor`.
            mamba_cache_params: Conv/SSM state tensors for this layer and
                the cache-slot indices.

        Returns:
            The first element of the `out_proj` output.
        """
        # 1. Gated MLP's linear projection, split into the conv input
        #    and the gate.
        projected = self.in_proj(hidden_states)[0].transpose(-2, -1)
        hidden_states, gate = projected.chunk(2, dim=-2)

        # 2. Convolution sequence transformation. Drop the singleton axis
        #    added in __init__ to recover a 2-D kernel.
        weight = self.conv1d.weight
        conv_weights = weight.view(weight.size(0), weight.size(2))

        use_varlen_kernels = (attn_metadata.query_start_loc is not None and
                              attn_metadata.context_lens_tensor is not None)

        if use_varlen_kernels:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            conv_out = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = conv_out.transpose(0, 1)

        # 3. State Space Model sequence transformation.
        # 3.a. Input-dependent time_step, B and C.
        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
        split_sizes = [
            self.time_step_rank, self.ssm_state_size, self.ssm_state_size
        ]
        time_step, B, C = torch.split(ssm_parameters, split_sizes, dim=-1)

        # Note that Jamba normalizes B, C, and time_step here but Mamba
        # doesn't.
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)

        # 3.b. Perform the recurrence y ← SSM(A, B, C)(x).
        if hasattr(self.dt_proj, "bias"):
            time_proj_bias = self.dt_proj.bias.float()
        else:
            time_proj_bias = None

        if use_varlen_kernels:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            step_out = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = step_out.transpose(0, 1)

        # 4. Final linear projection.
        return self.out_proj(scan_outputs.transpose(-2, -1))[0]
class MambaDecoderLayer(nn.Module):
    """A single Mamba layer: RMSNorm followed by a `MambaMixer`."""

    def __init__(self,
                 config: MambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Create the mixer and the normalization modules.

        Args:
            config: Supplies `hidden_size` and `layer_norm_epsilon`, and is
                passed on to `MambaMixer`.
            layer_idx: Index of this layer, passed on to `MambaMixer`.
            cache_config: Accepted but not read by this constructor.
            quant_config: Accepted but not read by this constructor.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.mixer = MambaMixer(config, layer_idx)

        norm_eps = config.layer_norm_epsilon
        self.norm = RMSNorm(config.hidden_size, eps=norm_eps)
        # Registered on the module, but `forward` below does not call it.
        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Normalize, then mix.

        Args:
            hidden_states: Input activations.
            attn_metadata: Passed through to the mixer.
            residual: Running residual, or `None` for the first layer.
            mamba_cache_params: Cache tensors for this layer, passed to
                the mixer.
            **kwargs: Extra keyword arguments are accepted and ignored.

        Returns:
            A `(hidden_states, residual)` tuple.
        """
        if residual is not None:
            # `norm` is called with both tensors and returns the
            # normalized output together with the new residual.
            normed, residual = self.norm(hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            normed = self.norm(hidden_states)

        mixed = self.mixer(normed, attn_metadata, mamba_cache_params)
        return mixed, residual
class MambaModel(nn.Module):
    """Embedding layer, a stack of `MambaDecoderLayer`s and a final norm."""

    def __init__(
        self,
        config: MambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the backbone.

        Args:
            config: Supplies `pad_token_id`, `vocab_size`, `hidden_size`,
                `num_hidden_layers` and `layer_norm_epsilon`.
            quant_config: Passed on to every decoder layer.
            cache_config: Passed on to every decoder layer.
            lora_config: When truthy, enlarges the embedding vocabulary by
                `lora_extra_vocab_size * (max_loras or 1)`.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.layers = nn.ModuleList([
            MambaDecoderLayer(config,
                              layer_idx=idx,
                              cache_config=cache_config,
                              quant_config=quant_config)
            for idx in range(config.num_hidden_layers)
        ])
        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
    ) -> torch.Tensor:
        """Embed `input_ids` and run them through every layer.

        Args:
            input_ids: Token ids fed to the embedding layer.
            positions: Forwarded to each layer as a keyword argument.
            attn_metadata: Forwarded to each layer.
            mamba_cache_params: Cache covering all layers; each layer gets
                the slice returned by `at_layer_idx(i)`.

        Returns:
            The first element of the final `norm_f(hidden_states, residual)`
            result.
        """
        hidden_states = self.embeddings(input_ids)
        residual = None

        for idx, layer in enumerate(self.layers):
            layer_cache = mamba_cache_params.at_layer_idx(idx)
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=layer_cache)

        normed, _ = self.norm_f(hidden_states, residual)
        return normed
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    """Mamba backbone plus logits computation, sampling and weight loading.

    The per-sequence conv/SSM state is held in a `MambaCacheManager` that is
    created lazily on the first `forward` call.
    """

    def __init__(
        self,
        config: MambaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        """Build the model.

        Args:
            config: Mamba model configuration.
            cache_config: Optional cache configuration; prefix caching must
                not be enabled on it.
            quant_config: Passed on to the backbone.
            lora_config: Passed on to the backbone; when truthy its
                `lora_extra_vocab_size` is added to the unpadded vocab size.
            scheduler_config: When given, its `max_num_seqs` sizes the
                Mamba cache in `forward`.

        Raises:
            AssertionError: If `cache_config` enables prefix caching.
        """
        # `cache_config` defaults to None, so only inspect it when it was
        # actually provided.
        assert cache_config is None \
            or not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.backbone = MambaModel(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # The output head is the very same module as the input embeddings.
        self.lm_head = self.backbone.embeddings

        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
        """Run the backbone and return its hidden states.

        Args:
            input_ids: Token ids for the current step.
            positions: Passed to the backbone.
            kv_caches: Accepted but not used by this method.
            attn_metadata: Passed to the cache manager and the backbone.
            intermediate_tensors: Accepted but not used by this method.
            **kwargs: Forwarded to `MambaCacheManager.current_run_tensors`.

        Returns:
            The hidden states produced by the backbone.
        """
        if self.mamba_cache is None:
            # Create the cache on first use. Without a scheduler config the
            # size falls back to max(_BATCH_SIZES_TO_CAPTURE) + 2.
            max_batch_size = (_get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, self.config.num_hidden_layers,
                max_batch_size, *self._get_mamba_cache_shape())

        (
            mamba_cache_tensors,
            state_indices_tensor,
        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
                                                 **kwargs)

        # Element 0 is passed as `conv_state`, element 1 as `ssm_state`.
        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
                                              mamba_cache_tensors[1],
                                              state_indices_tensor)

        hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                      mamba_cache_params)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the cache manager's method of the same name.

        Requires `self.mamba_cache` to have been created by `forward`.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the cache manager's method of the same name.

        Requires `self.mamba_cache` to have been created by `forward`.
        """
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return `(conv_state_shape, temporal_state_shape)`.

        Both shapes start with `intermediate_size // world_size`; the conv
        shape ends with `conv_kernel - 1` and the temporal shape with
        `state_size`.
        """
        world_size = get_tensor_model_parallel_world_size()
        conv_state_shape = (
            self.config.intermediate_size // world_size,
            self.config.conv_kernel - 1,
        )
        temporal_state_shape = (
            self.config.intermediate_size // world_size,
            self.config.state_size,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Apply the logits processor to `hidden_states` using `lm_head`."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on `logits` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into the matching parameters.

        Args:
            weights: `(name, tensor)` pairs from the checkpoint.

        Raises:
            KeyError: If a name (after renaming) matches no parameter and is
                not a skipped `.bias` entry.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # The checkpoint's `A_log` maps onto the parameter named `A`.
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            # Prefer a loader attached to the parameter; otherwise use the
            # default one.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
vllm/model_executor/models/mamba_cache.py
0 → 100644
View file @
6d2051cc
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
@dataclass
class MambaCacheParams:
    """Container for the Mamba cache tensors passed around during a run.

    ``conv_state`` and ``ssm_state`` are indexed by layer along their first
    dimension in ``at_layer_idx``; ``state_indices_tensor`` is carried along
    untouched.
    """
    conv_state: torch.Tensor = torch.Tensor()
    ssm_state: torch.Tensor = torch.Tensor()
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        """Return a new ``MambaCacheParams`` holding one layer's states.

        ``conv_state`` and ``ssm_state`` are indexed with ``layer_idx``;
        ``state_indices_tensor`` is shared with the original object.
        """
        layer_conv_state = self.conv_state[layer_idx]
        layer_ssm_state = self.ssm_state[layer_idx]
        return MambaCacheParams(
            conv_state=layer_conv_state,
            ssm_state=layer_ssm_state,
            state_indices_tensor=self.state_indices_tensor,
        )
class MambaCacheManager:
    """Own the Mamba state buffers and map sequences to slots inside them.

    ``self.mamba_cache`` is a ``(conv_state, temporal_state)`` tuple; each
    tensor has shape ``(num_mamba_layers, max_batch_size) + <state shape>``.
    The second dimension is the slot index that this manager hands out to
    ``(request_id, seq_id)`` pairs and reclaims when requests finish.
    """

    def __init__(self, dtype, num_mamba_layers, max_batch_size,
                 conv_state_shape, temporal_state_shape):
        """Allocate both state buffers on ``"cuda"`` and mark all slots free.

        ``conv_state_shape`` and ``temporal_state_shape`` are tuples that are
        appended to ``(num_mamba_layers, max_batch_size)``.
        """
        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                 conv_state_shape,
                                 dtype=dtype,
                                 device="cuda")
        temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                     temporal_state_shape,
                                     dtype=dtype,
                                     device="cuda")

        self.mamba_cache = (conv_state, temporal_state)

        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.free_cache_indices = list(range(max_batch_size))

    def current_run_tensors(self, input_ids: torch.Tensor,
                            attn_metadata: "AttentionMetadata", **kwargs):
        """
        Return the tensors for the current run's conv and ssm state.

        Without ``"seqlen_agnostic_capture_inputs"`` in ``kwargs`` the
        keyword arguments ``"request_ids_to_seq_ids"`` and
        ``"finished_requests_ids"`` are required: finished requests are
        released, slots are assigned, and ``(self.mamba_cache,
        state_indices_tensor)`` is returned. Otherwise the pair stored under
        ``"seqlen_agnostic_capture_inputs"`` is returned as-is.
        ``input_ids`` and ``attn_metadata`` are not read here.
        """
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            # Free slots first so they can be reused by this very run.
            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_mamba_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(state_indices,
                                                   dtype=torch.int32,
                                                   device="cuda")
            mamba_cache_tensors = self.mamba_cache

        else:
            # CUDA graph capturing runs
            (mamba_cache_tensors,
             state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]

        return (mamba_cache_tensors, state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer

        ``kwargs`` must contain ``"request_ids_to_seq_ids"`` and
        ``"finished_requests_ids"``; ``input_buffers`` must contain
        ``"seqlen_agnostic_capture_inputs"`` whose second element is the
        state-indices buffer that is overwritten in place.
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_mamba_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        # Pad the index list up to the length of the captured buffer.
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.

        Returns ``(self.mamba_cache, state_indices_tensor)`` where the index
        tensor has ``batch_size`` entries, all set to ``PAD_SLOT_ID``.
        """
        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                               dtype=torch.int32,
                                               device="cuda")
        return (self.mamba_cache, state_indices_tensor)

    def _copy_mamba_cache(self, from_index: int, to_index: int):
        """Copy slot ``from_index`` into slot ``to_index`` in every buffer."""
        assert len(self.mamba_cache) > 0
        for cache_t in self.mamba_cache:
            # Dim 0 is the layer axis, dim 1 the slot axis (see __init__).
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Assign (req_id,seq_id) pair to a `destination_index` index, if
        already occupied, move the occupying index to a free index.

        Returns ``PAD_SLOT_ID`` for requests listed in
        ``finished_requests_ids``; otherwise the (possibly newly allocated)
        slot index. Allocation pops from the end of
        ``self.free_cache_indices`` and raises ``IndexError`` if it is empty.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID
        elif cur_rid not in self.mamba_cache_indices_mapping:
            destination_index = self.free_cache_indices.pop()
            self.mamba_cache_indices_mapping[cur_rid] = {
                seq_id: destination_index
            }
            return destination_index
        elif seq_id not in (seq_ids2indices :=
                            self.mamba_cache_indices_mapping[cur_rid]):
            # parallel sampling , where n > 1, assume prefill have
            # already happened, so we copy the
            # existing cache into the siblings seq_ids caches
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self.free_cache_indices.pop()
            self._copy_mamba_cache(from_index=index_exists,
                                   to_index=destination_index)
            self.mamba_cache_indices_mapping[cur_rid][
                seq_id] = destination_index
            return destination_index
        else:
            # already exists
            return self.mamba_cache_indices_mapping[cur_rid][seq_id]

    def _prepare_current_run_mamba_cache(
            self, request_ids_to_seq_ids: Dict[str, List[int]],
            finished_requests_ids: List[str]) -> List[int]:
        """Return one slot index per ``(request_id, seq_id)`` pair.

        Pairs are visited in the iteration order of
        ``request_ids_to_seq_ids`` and of each request's ``seq_ids``.
        """
        # typing.List (not builtin list[...]) keeps this annotation
        # consistent with the rest of the module, which uses typing generics.
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        """Return the slots of finished requests to the free list.

        Request ids that are not tracked are ignored.
        """
        for req_id in finished_seq_groups_req_ids:
            seq_ids2indices = self.mamba_cache_indices_mapping.pop(
                req_id, None)
            if seq_ids2indices is not None:
                self.free_cache_indices.extend(seq_ids2indices.values())
vllm/model_executor/models/minicpm.py
View file @
6d2051cc
...
...
@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import
math
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -30,10 +30,10 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
FatreluAndMul
,
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
MiniCPMMoE
(
nn
.
Module
):
...
...
@@ -151,6 +152,7 @@ class MiniCPMMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act_param
:
float
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -162,10 +164,13 @@ class MiniCPMMLP(nn.Module):
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
==
"silu"
:
self
.
act_fn
=
SiluAndMul
()
elif
hidden_act
==
"fatrelu"
:
self
.
act_fn
=
FatreluAndMul
(
threshold
=
hidden_act_param
)
else
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
"Only silu and fatrelu are supported for now."
)
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
...
...
@@ -264,7 +269,7 @@ class MiniCPMDecoderLayer(nn.Module):
def
__init__
(
self
,
config
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
...
...
@@ -303,6 +308,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
self
.
config
.
intermediate_size
,
hidden_act
=
self
.
config
.
hidden_act
,
hidden_act_param
=
getattr
(
self
.
config
,
"hidden_act_param"
,
0.
),
quant_config
=
self
.
quant_config
,
)
else
:
...
...
@@ -346,10 +352,11 @@ class MiniCPMModel(nn.Module):
def
__init__
(
self
,
config
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -365,15 +372,24 @@ class MiniCPMModel(nn.Module):
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
_init_layers
()
self
.
_init_layers
(
prefix
,
config
,
cache_config
,
quant_config
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
self
.
config
.
hidden_size
))
def
_init_layers
(
self
):
self
.
layers
=
nn
.
ModuleList
([
MiniCPMDecoderLayer
(
self
.
config
,
self
.
cache_config
,
self
.
quant_config
)
for
_
in
range
(
self
.
config
.
num_hidden_layers
)
])
def
_init_layers
(
self
,
prefix
:
str
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
MiniCPMDecoderLayer
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
embedding
=
self
.
embed_tokens
(
input_ids
)
...
...
@@ -387,27 +403,36 @@ class MiniCPMModel(nn.Module):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
len
(
self
.
layer
s
)
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_
layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
MiniCPMForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
MiniCPMForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -454,22 +479,25 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
if
not
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
ParallelLMHead
(
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
self
.
lm_head
=
ParallelLMHead
(
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
self
.
scale_width
=
self
.
config
.
hidden_size
/
self
.
config
.
dim_model_base
self
.
logits_processor
=
LogitsProcessor
(
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
_init_model
(
self
):
self
.
model
=
MiniCPMModel
(
config
=
self
.
config
,
...
...
@@ -484,7 +512,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
hidden_states
...
...
@@ -495,11 +523,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
hidden_states
=
hidden_states
/
self
.
scale_width
if
self
.
config
.
tie_word_embeddings
:
lm_head
=
self
.
model
.
embed_tokens
else
:
lm_head
=
self
.
lm_head
logits
=
self
.
logits_processor
(
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -548,6 +572,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -557,6 +583,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
...
...
@@ -568,6 +596,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/minicpm3.py
View file @
6d2051cc
...
...
@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
...
...
@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.models.minicpm
import
(
MiniCPMDecoderLayer
,
MiniCPMForCausalLM
,
MiniCPMModel
)
from
.utils
import
make_layers
class
MiniCPM3Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
config
:
PretrainedConfig
,
hidden_size
:
int
,
num_heads
:
int
,
qk_nope_head_dim
:
int
,
...
...
@@ -199,15 +201,43 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
class
MiniCPM3Model
(
MiniCPMModel
):
def
_init_layers
(
self
):
self
.
layers
=
nn
.
ModuleList
([
MiniCPM3DecoderLayer
(
self
.
config
,
self
.
cache_config
,
self
.
quant_config
)
for
_
in
range
(
self
.
config
.
num_hidden_layers
)
])
def
_init_layers
(
self
,
prefix
:
str
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
MiniCPM3DecoderLayer
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
class
MiniCPM3ForCausalLM
(
MiniCPMForCausalLM
):
packed_modules_mapping
=
{
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"kv_a_proj_with_mqa"
,
"q_a_proj"
,
"q_b_proj"
,
"kv_b_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
,
"lm_head"
,
]
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM
def
_init_model
(
self
):
self
.
model
=
MiniCPM3Model
(
config
=
self
.
config
,
...
...
vllm/model_executor/models/minicpmv.py
View file @
6d2051cc
...
...
@@ -24,33 +24,33 @@
import
math
import
re
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
)
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.types
from
PIL
import
Image
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
transformers
import
PretrainedConfig
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
MultiModalConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.resampler
import
(
Resampler2
,
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
get_2d_sincos_pos_embed
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.minicpm
import
MiniCPMModel
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.models.utils
import
LLMWrapper
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
...
...
@@ -59,16 +59,19 @@ from vllm.multimodal.utils import cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
is_pp_missing_parameter
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
"llm.model"
:
"llm"
,
}
RawImageType
=
Union
[
Image
.
Image
,
torch
.
Tensor
]
class
MiniCPMVImageInput
(
TypedDict
):
class
MiniCPMVRawImageInput
(
TypedDict
):
"""Input mapper input with auxiliary data for computing image bounds."""
image
:
Image
.
Imag
e
image
:
Raw
Image
Typ
e
# Image bounds token ids in 0-dim scaler tensor.
im_start_id
:
torch
.
Tensor
...
...
@@ -78,7 +81,8 @@ class MiniCPMVImageInput(TypedDict):
class
MiniCPMVImagePixelInputs
(
TypedDict
):
pixel_values
:
List
[
torch
.
Tensor
]
type
:
Literal
[
"pixel_values"
]
data
:
List
[
torch
.
Tensor
]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
...
...
@@ -101,59 +105,28 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
class
MiniCPMVImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
"""
class
BaseResampler
(
nn
.
Module
):
image_bounds
:
torch
.
Tensor
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
def
__init__
(
self
,
num_queries
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
)
->
None
:
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
MiniCPMVImageInputs
=
Union
[
MiniCPMVImagePixelInputs
,
MiniCPMVImageEmbeddingInputs
]
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
0.02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
ReplicatedLinear
(
kv_dim
,
embed_dim
,
bias
=
False
)
else
:
# Maintain the same return value with ReplicatedLinear.forward
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
nn
.
Identity
()(
*
args
,
**
kwargs
),
None
,
)
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
self
.
ln_post
=
norm_layer
(
embed_dim
)
self
.
proj
=
nn
.
Parameter
(
(
embed_dim
**-
0.5
)
*
torch
.
randn
(
embed_dim
,
embed_dim
))
def
_init_weights
(
self
,
m
:
nn
.
Module
)
->
None
:
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
_repeat
(
self
,
query
,
N
:
int
):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
class
Resampler2_5
(
BaseResampler
):
...
...
@@ -246,22 +219,22 @@ class Resampler2_5(BaseResampler):
def
_build_image_input
(
ctx
:
InputContext
,
image
:
Image
.
Imag
e
)
->
MiniCPMVImageInput
:
image
:
Raw
Image
Typ
e
)
->
MiniCPMV
Raw
ImageInput
:
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
,
trust_remote_code
=
ctx
.
model_config
.
trust_remote_code
)
if
hasattr
(
tokenizer
,
"slice_start_id"
):
return
MiniCPMVImageInput
(
return
MiniCPMV
Raw
ImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
),
slice_start_id
=
torch
.
tensor
(
tokenizer
.
slice_start_id
),
slice_end_id
=
torch
.
tensor
(
tokenizer
.
slice_end_id
))
else
:
return
MiniCPMVImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
))
return
MiniCPMV
Raw
ImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
))
def
get_version_by_config
(
config
:
PretrainedConfig
)
->
Tuple
[
int
,
...]:
...
...
@@ -284,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
return
SequenceData
.
from_token_counts
((
0
,
seq_len
))
return
SequenceData
.
from_
prompt_
token_counts
((
0
,
seq_len
))
def
dummy_image_for_minicpmv
(
ctx
:
InputContext
,
hf_config
:
PretrainedConfig
,
...
...
@@ -307,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
return
seq_data
,
mm_data
def
input_processor_for_minicpmv
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
def
input_processor_for_minicpmv
(
ctx
:
InputContext
,
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
model_config
=
ctx
.
model_config
version
=
get_version_by_config
(
model_config
.
hf_config
)
tokenizer
=
cached_get_tokenizer
(
...
...
@@ -325,27 +298,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
image_processor
.
\
get_slice_image_placeholder
(
image_size
,
num_image
)
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt
=
inputs
.
get
(
"prompt"
)
token_ids
=
inputs
.
get
(
"prompt_token_ids"
)
if
prompt
is
None
:
token_ids
=
llm_inputs
.
get
(
"prompt_token_ids"
)
prompt
=
tokenizer
.
decode
(
token_ids
)
pattern
=
"(<image>./</image>)"
images
=
multi_modal_data
[
"image"
]
if
isinstance
(
images
,
Image
.
Image
):
images
=
[
images
]
image_tags
=
re
.
findall
(
pattern
,
prompt
)
if
len
(
image_tags
)
==
0
:
new_token_ids
=
token_ids
new_prompt
=
prompt
else
:
if
isinstance
(
images
,
dict
):
image_size_list
=
images
.
get
(
"image_size_list"
)
images
=
[
images
.
get
(
"image_embeds"
)]
else
:
if
isinstance
(
images
,
Image
.
Image
):
images
=
[
images
]
image_size_list
=
[
image
.
size
for
image
in
images
]
text_chunks
=
prompt
.
split
(
pattern
)
new_prompt_chunks
:
List
[
str
]
=
[]
for
i
in
range
(
len
(
image
s
)):
for
i
in
range
(
len
(
image
_size_list
)):
new_prompt_chunks
+=
[
text_chunks
[
i
],
get_placeholder
(
image
s
[
i
].
size
,
i
)
get_placeholder
(
image
_size_list
[
i
]
,
i
)
]
new_prompt_chunks
.
append
(
text_chunks
[
-
1
])
new_prompt
=
""
.
join
(
new_prompt_chunks
)
...
...
@@ -355,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
_build_image_input
(
ctx
,
image
)
for
image
in
images
]
llm_inputs
=
LLMI
nputs
(
return
token_i
nputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
,
)
return
llm_inputs
def
input_mapper_for_minicpmv
(
ctx
:
InputContext
,
data
:
object
):
...
...
@@ -375,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
if
not
isinstance
(
data
,
list
):
raise
ValueError
(
"Image input must be list of MiniCPMVImageInput, got (%s)"
,
data
)
batch_data
=
image_processor
\
.
preprocess
([
img
[
"image"
]
for
img
in
data
],
return_tensors
=
"pt"
)
\
.
data
if
len
(
data
)
>
0
and
isinstance
(
data
[
0
][
'image'
],
torch
.
Tensor
):
batch_data
=
{
"image_embeds"
:
data
[
0
][
'image'
],
}
else
:
batch_data
=
image_processor
\
.
preprocess
([
img
[
"image"
]
for
img
in
data
],
return_tensors
=
"pt"
)
\
.
data
if
len
(
data
)
>
0
:
batch_data
[
"im_start_id"
]
=
data
[
0
][
"im_start_id"
]
...
...
@@ -389,7 +372,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
return
MultiModalInputs
(
batch_data
)
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
):
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
...
...
@@ -426,10 +409,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
llm
.
make_empty_intermediate_tensors
)
def
get_embedding
(
self
,
input_ids
:
torch
.
Tensor
,
image_inputs
:
Optional
[
MiniCPMVImage
Pixel
Inputs
],
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
if
hasattr
(
self
.
config
,
"scale_emb"
):
...
...
@@ -438,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if
image_inputs
is
None
:
# No image
vision_hidden_states
=
torch
.
tensor
([],
device
=
input_ids
.
device
)
else
:
vision_hidden_states
=
self
.
get_vision_hidden_states
(
image_inputs
)
if
image_inputs
[
"type"
]
==
"image_embeds"
:
vision_hidden_states
=
(
image_inputs
[
"data"
].
type
(
vlm_embedding
.
dtype
).
to
(
vlm_embedding
.
device
))
else
:
vision_hidden_states
=
self
.
get_vision_hidden_states
(
image_inputs
)
# See NOTE in _parse_and_validate_inputs
image_bounds
=
image_inputs
[
"image_bounds"
]
...
...
@@ -489,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self
,
input_ids
:
torch
.
Tensor
,
**
kwargs
:
object
,
)
->
Optional
[
MiniCPMVImage
Pixel
Inputs
]:
)
->
Optional
[
MiniCPMVImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
[])
tgt_sizes
=
kwargs
.
pop
(
"tgt_sizes"
,
[])
im_start_id
=
kwargs
.
pop
(
"im_start_id"
,
None
)
im_end_id
=
kwargs
.
pop
(
"im_end_id"
,
None
)
slice_start_id
=
kwargs
.
pop
(
"slice_start_id"
,
None
)
slice_end_id
=
kwargs
.
pop
(
"slice_end_id"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
image_embeds
is
not
None
:
return
MiniCPMVImageEmbeddingInputs
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
,
im_start_id
,
im_end_id
,
slice_start_id
,
slice_end_id
),
data
=
image_embeds
,
type
=
"image_embeds"
,
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
...
...
@@ -526,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if
len
(
pixel_values_flat
)
==
0
:
return
None
im_start_id
=
kwargs
.
pop
(
"im_start_id"
,
None
)
im_end_id
=
kwargs
.
pop
(
"im_end_id"
,
None
)
slice_start_id
=
kwargs
.
pop
(
"slice_start_id"
,
None
)
slice_end_id
=
kwargs
.
pop
(
"slice_end_id"
,
None
)
if
im_start_id
is
None
:
return
None
...
...
@@ -537,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
image_bounds
=
self
.
_get_image_bounds
(
input_ids
,
im_start_id
,
im_end_id
,
slice_start_id
,
slice_end_id
),
pixel_values
=
pixel_values_flat
,
data
=
pixel_values_flat
,
tgt_sizes
=
torch
.
stack
(
tgt_sizes_flat
),
type
=
"pixel_values"
,
)
def
forward
(
...
...
@@ -550,9 +552,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
Any
,
)
->
torch
.
Tensor
:
image_inputs
=
self
.
_parse_and_validate_inputs
(
input_ids
,
**
kwargs
)
if
intermediate_tensors
is
not
None
:
vlm_embeddings
=
None
else
:
image_inputs
=
self
.
_parse_and_validate_inputs
(
input_ids
,
**
kwargs
)
vlm_embeddings
,
_
=
self
.
get_embedding
(
input_ids
,
image_inputs
)
vlm_embeddings
,
_
=
self
.
get_embedding
(
input_ids
,
image_inputs
)
output
=
self
.
llm
(
input_ids
=
None
,
...
...
@@ -609,6 +614,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
if
is_pp_missing_parameter
(
name
.
replace
(
weight_name
,
param_name
),
self
):
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -616,11 +624,21 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"llm"
,
connector
=
"resampler"
,
tower_model
=
"vpm"
)
def
init_llm
(
self
,
config
:
PretrainedConfig
,
...
...
@@ -643,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
...
...
@@ -669,9 +687,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
MiniCPMModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
return
LLMWrapper
(
MiniCPMModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
# TODO :refactor this vision model
...
...
@@ -697,6 +717,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return
model
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_tokens
(
input_ids
)
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2
(
...
...
@@ -733,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"
pixel_values
"
]
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"
data
"
]
return
self
.
get_vision_embedding
(
pixel_values
)
...
...
@@ -743,7 +766,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return
"resampler"
in
name
or
"vpm"
in
name
class
MiniCPMV2_5
(
MiniCPMVBaseModel
):
class
MiniCPMV2_5
(
MiniCPMVBaseModel
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
# vision encoder
"fc1"
,
"fc2"
,
"out_proj"
,
# language model
"qkv_proj"
,
# same name with vision encoder
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
# resampler
"kv_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
...
...
@@ -751,6 +801,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
assert
self
.
version
==
(
2
,
5
)
...
...
@@ -761,9 +812,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
LlamaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
return
LLMWrapper
(
LlamaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
...
...
@@ -792,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"
pixel_values
"
]
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"
data
"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
device
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
device
...
...
@@ -825,7 +877,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
return
"resampler"
in
name
class
MiniCPMV2_6
(
MiniCPMVBaseModel
):
class
MiniCPMV2_6
(
MiniCPMVBaseModel
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
# vision encoder
"fc1"
,
"fc2"
,
"out_proj"
,
# language model
"qkv_proj"
,
# same name with vision encoder
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
# resampler
"kv_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
...
...
@@ -843,20 +923,15 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
return
LLMWrapper
(
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
# A custom version of SiglipVisionTransformer, won't work with TP
from
vllm.model_executor.models.na_vit
import
SiglipVisionTransformer
if
self
.
config
.
_attn_implementation
==
"flash_attention_2"
:
self
.
config
.
vision_config
.
_attn_implementation
=
"flash_attention_2"
else
:
# not support sdpa
self
.
config
.
vision_config
.
_attn_implementation
=
"eager"
model
=
SiglipVisionTransformer
(
self
.
config
.
vision_config
)
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
...
...
@@ -870,7 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
)
return
resampler
def
get_vision_embedding
(
...
...
@@ -883,12 +957,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
pixel_values
,
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
,
)
.
last_hidden_state
)
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"
pixel_values
"
]
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"
data
"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
device
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
device
...
...
@@ -915,12 +989,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
all_pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
,
)
.
last_hidden_state
)
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
or
"vpm"
in
name
return
"resampler"
in
name
_SUPPORT_VERSION
=
{
...
...
@@ -934,20 +1008,25 @@ _SUPPORT_VERSION = {
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
class
MiniCPMV
(
MiniCPMVBaseModel
):
class
MiniCPMV
(
MiniCPMVBaseModel
,
SupportsLoRA
):
"""
Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
"""
def
__new__
(
cls
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping
=
{}
supported_lora_modules
=
[]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__new__
(
cls
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
if
not
hasattr
(
config
,
"version"
):
if
config
.
hidden_size
==
2304
and
config
.
query_num
==
64
:
version
=
(
2
,
0
)
...
...
vllm/model_executor/models/mixtral.py
View file @
6d2051cc
...
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
MixtralMoE
(
nn
.
Module
):
...
...
@@ -276,6 +276,9 @@ class MixtralModel(nn.Module):
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
forward
(
self
,
...
...
@@ -284,7 +287,7 @@ class MixtralModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
...
...
@@ -306,7 +309,7 @@ class MixtralModel(nn.Module):
return
hidden_states
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
...
...
@@ -319,10 +322,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"embed_tokens"
,
"lm_head"
,
"qkv_proj"
,
"o_proj"
,
"embed_tokens"
,
"lm_head"
,
"w1"
,
"w2"
,
"w3"
,
"gate"
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
...
...
@@ -365,6 +366,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -373,7 +376,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
hidden_states
...
...
@@ -387,20 +390,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
6d2051cc
...
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
...
...
@@ -31,7 +31,7 @@ from transformers import MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
MixtralMLP
(
nn
.
Module
):
...
...
@@ -296,6 +299,7 @@ class MixtralModel(nn.Module):
config
:
MixtralConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -305,13 +309,15 @@ class MixtralModel(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
MixtralDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
forward
(
self
,
...
...
@@ -319,19 +325,30 @@ class MixtralModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
MixtralForCausalLM
(
nn
.
Module
):
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
def
__init__
(
...
...
@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module):
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module):
if
(
"block_sparse_moe.experts."
in
name
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/mllama.py
View file @
6d2051cc
...
...
@@ -14,10 +14,10 @@
# limitations under the License.
"""PyTorch Mllama model."""
import
math
from
array
import
array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
...
...
@@ -28,12 +28,16 @@ from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast
)
from
transformers.models.mllama.image_processing_mllama
import
(
get_optimal_tiled_canvas
)
from
transformers.models.mllama.processing_mllama
import
(
get_cross_attention_token_mask
)
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
EncoderDecoderInputs
,
InputContext
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -47,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
SequenceData
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
...
...
@@ -72,31 +76,45 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs
def
input_processor_for_mllama
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
def
_get_num_image_in_last_group
(
prompt_token_ids
:
List
[
int
])
->
int
:
num_images
=
0
for
token_id
in
prompt_token_ids
[::
-
1
]:
if
token_id
==
MLLAMA_IMAGE_TOKEN_ID
:
num_images
+=
1
elif
num_images
>
0
:
break
return
num_images
def
input_processor_for_mllama
(
ctx
:
InputContext
,
inputs
:
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]):
# move encoder_prompt to prompt
if
llm_
inputs
.
get
(
"prompt"
)
is
None
:
llm_
inputs
[
"prompt"
]
=
llm_
inputs
[
"encoder_prompt"
]
llm_
inputs
[
"prompt_token_ids"
]
=
llm_
inputs
[
"encoder_prompt_token_ids"
]
if
inputs
.
get
(
"prompt"
)
is
None
:
inputs
[
"prompt"
]
=
inputs
[
"encoder_prompt"
]
inputs
[
"prompt_token_ids"
]
=
inputs
[
"encoder_prompt_token_ids"
]
# process multi-modal data
assert
"decoder_multi_modal_data"
not
in
llm_inputs
,
\
"multi-modal data should be put in encoder message of mllama"
multi_modal_data
=
llm_inputs
.
get
(
"encoder_multi_modal_data"
)
multi_modal_data
=
inputs
.
get
(
"encoder_multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
\
or
multi_modal_data
[
"image"
]
is
None
:
# text-only
llm_
inputs
[
"encoder_prompt"
]
=
""
llm_
inputs
[
"encoder_prompt_token_ids"
]
=
[]
llm_
inputs
[
"encoder_multi_modal_data"
]
=
{}
return
llm_
inputs
inputs
[
"encoder_prompt"
]
=
""
inputs
[
"encoder_prompt_token_ids"
]
=
[]
inputs
[
"encoder_multi_modal_data"
]
=
{}
return
inputs
# get num_tiles
if
isinstance
(
multi_modal_data
[
'image'
],
Image
.
Image
):
multi_modal_data
[
'image'
]
=
[
multi_modal_data
[
'image'
]]
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images
=
_get_num_image_in_last_group
(
inputs
[
"prompt_token_ids"
])
hf_config
=
ctx
.
model_config
.
hf_config
num_tiles
=
0
for
image
in
multi_modal_data
[
"image"
]:
for
image
in
multi_modal_data
[
"image"
]
[::
-
1
]
:
width
,
height
=
image
.
size
tile_size
=
hf_config
.
vision_config
.
image_size
canvas_height
,
canvas_width
=
get_optimal_tiled_canvas
(
...
...
@@ -108,17 +126,21 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
num_tiles_height
=
canvas_height
//
tile_size
num_tiles_width
=
canvas_width
//
tile_size
num_tiles
+=
num_tiles_height
*
num_tiles_width
num_decode_images
-=
1
if
num_decode_images
==
0
:
break
# set encoder prompt based on num_tiles
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert
hf_config
.
vision_config
.
image_size
%
14
==
0
,
\
"chunk size should be multiple of 14"
token_per_chunk
=
(
hf_config
.
vision_config
.
image_size
//
14
)
**
2
+
1
num_tokens
=
num_tiles
*
token_per_chunk
llm_inputs
[
"encoder_prompt"
]
=
MLLAMA_IMAGE_TOKEN
*
num_tokens
llm_inputs
[
"encoder_prompt_token_ids"
]
=
[
MLLAMA_IMAGE_TOKEN_ID
]
*
num_tokens
inputs
[
"encoder_prompt"
]
=
MLLAMA_IMAGE_TOKEN
*
num_tokens
inputs
[
"encoder_prompt_token_ids"
]
=
[
MLLAMA_IMAGE_TOKEN_ID
]
*
num_tokens
return
llm_
inputs
return
inputs
def
get_max_mllama_image_tokens
(
ctx
:
InputContext
)
->
int
:
...
...
@@ -131,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert
seq_len
>=
num_images
,
\
"seq_len should be greater than or equal to num_images"
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
MLLAMA_IMAGE_TOKEN_ID
])
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
.
from_prompt_token_counts
(
(
MLLAMA_IMAGE_TOKEN_ID
,
num_images
),
(
0
,
seq_len
-
num_images
),
)
def
dummy_encoder_seq_data
(
ctx
:
InputContext
,
num_images
:
int
):
num_tokens
=
get_max_mllama_image_tokens
(
ctx
)
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
MLLAMA_IMAGE_TOKEN_ID
])
*
num_tokens
return
SequenceData
(
token
_id
s
)
return
SequenceData
.
from_prompt_token_counts
(
(
MLLAMA_IMAGE_TOKEN_ID
,
num_
tokens
)
)
def
dummy_image
(
num_images
:
int
,
):
...
...
@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
cross_attention_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
...
...
@@ -697,15 +721,71 @@ class MllamaTextCrossAttention(nn.Module):
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
head_dim
)
q
=
self
.
q_norm
(
q
)
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
if
attention_mask
is
not
None
:
output
=
self
.
attention_with_mask
(
q
,
k
,
v
,
kv_cache
,
attention_mask
,
kv_range_for_decode
,
attn_metadata
)
else
:
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
out
,
_
=
self
.
o_proj
(
output
)
return
out
def
attention_with_mask
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
List
[
Tuple
[
int
,
int
]],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Skip writing kv-cache for the initial profiling run.
if
len
(
kv_cache
.
shape
)
==
3
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
PagedAttention
.
write_to_paged_cache
(
cached_k
,
cached_v
,
key_cache
,
value_cache
,
attn_metadata
.
cross_slot_mapping
,
"auto"
,
1.0
,
1.0
)
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
# can be optimized by xformers.BlockDiagonalMask.
# The mask is specially calculated for supporting multi
# images and interleaved images.
q_len
=
q
.
shape
[
0
]
kv_len
=
k
.
shape
[
0
]
q
=
q
.
transpose
(
0
,
1
).
view
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
q_len
,
self
.
head_dim
)
k
=
k
.
transpose
(
0
,
1
)[:,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
kv_len
,
self
.
head_dim
)
v
=
v
.
transpose
(
0
,
1
)[:,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
kv_len
,
self
.
head_dim
)
attention_mask
=
attention_mask
.
view
(
1
,
1
,
q_len
,
kv_len
)
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attention_mask
,
is_causal
=
False
)
output
=
output
.
permute
(
2
,
0
,
1
,
3
).
reshape
(
q_len
,
self
.
num_local_heads
*
self
.
head_dim
)
return
output
class
MllamaCrossAttentionDecoderLayer
(
torch
.
nn
.
Module
):
"""Cross-attention transformer block with tanh-gated attention
...
...
@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states
:
torch
.
Tensor
,
cross_attention_states
:
torch
.
Tensor
,
cross_attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
torch
.
Tensor
,
kv_cache
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
...
...
@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states
=
self
.
cross_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
cross_attention_states
=
cross_attention_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
...
...
@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module):
positions
:
Optional
[
torch
.
LongTensor
],
cross_attention_states
:
Optional
[
torch
.
LongTensor
],
cross_attention_mask
:
Optional
[
torch
.
LongTensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
...
...
@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module):
hidden_states
=
hidden_states
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_cache
=
kv_caches
[
idx
],
...
...
@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module):
positions
:
Optional
[
torch
.
LongTensor
],
cross_attention_states
:
Optional
[
torch
.
LongTensor
],
cross_attention_mask
:
Optional
[
torch
.
LongTensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
...
...
@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module):
positions
=
positions
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
...
...
@@ -1026,36 +1112,102 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
AssertionError
(
"This line should be unreachable."
)
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
):
attn_metadata
:
AttentionMetadata
,
actual_encoder_seq_lens
:
List
[
int
]):
cross_attention_states_flat
=
torch
.
zeros
(
sum
(
a
ttn_metadata
.
encoder_seq_lens
),
sum
(
a
ctual_
encoder_seq_lens
),
cross_attention_states
.
shape
[
-
1
],
device
=
cross_attention_states
.
device
,
dtype
=
cross_attention_states
.
dtype
)
start_pos
=
0
for
seq_len
,
vision_token_in_batch
in
zip
(
attn_metadata
.
encoder_seq_lens
,
cross_attention_states
):
for
seq_len
,
vision_token_in_batch
in
zip
(
actual_encoder_seq_lens
,
cross_attention_states
):
end_pos
=
start_pos
+
seq_len
cross_attention_states_flat
[
start_pos
:
end_pos
]
=
vision_token_in_batch
[:
seq_len
]
start_pos
=
end_pos
cross_attention_states
=
cross_attention_states_flat
return
cross_attention_states
def
get_cross_attention_states
(
self
,
image_inputs
:
MllamaImagePixelInputs
,
attn_metadata
:
AttentionMetadata
,
actual_encoder_seq_lens
:
List
[
int
],
)
->
Tuple
[
torch
.
Tensor
]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values
=
image_inputs
[
'data'
]
aspect_ratio_ids
=
image_inputs
[
'aspect_ratio_ids'
]
aspect_ratio_mask
=
image_inputs
[
'aspect_ratio_mask'
]
cross_attention_states
=
self
.
vision_model
(
pixel_values
,
aspect_ratio_ids
,
aspect_ratio_mask
)
cross_attention_states
=
self
.
multi_modal_projector
(
cross_attention_states
)
bsz
,
_
,
_
,
_
,
image_token_dim
=
tuple
(
cross_attention_states
.
shape
)
cross_attention_states
=
cross_attention_states
.
view
(
bsz
,
-
1
,
image_token_dim
)
cross_attention_states
=
self
.
flat_encoder_result
(
cross_attention_states
,
attn_metadata
,
actual_encoder_seq_lens
)
return
cross_attention_states
def
get_cross_attention_mask
(
self
,
input_ids
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
num_tiles
:
List
[
List
[
int
]],
num_tokens_per_tile
:
int
,
dtype
:
torch
.
dtype
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
token_ids
=
input_ids
.
tolist
()
start
=
0
batch_token_ids
=
[]
for
seq_len
in
attn_metadata
.
seq_lens
:
batch_token_ids
.
append
(
token_ids
[
start
:
start
+
seq_len
])
start
+=
seq_len
sparse_mask
=
[
get_cross_attention_token_mask
(
t
,
MLLAMA_IMAGE_TOKEN_ID
)
for
t
in
batch_token_ids
]
# Skip generating cross-attention mask if all samples
# are text-only or have only 1 leading image.
if
skip_attention_mask
(
sparse_mask
):
return
None
,
None
dense_mask
,
tile_range_for_decode
=
\
convert_sparse_cross_attention_mask_to_dense
(
sparse_mask
,
num_tiles
,
attn_metadata
.
seq_lens
)
cross_attention_mask
=
\
convert_dense_cross_attention_mask_to_tensor
(
dense_mask
,
num_tokens_per_tile
,
input_ids
.
device
,
dtype
)
kv_range_for_decode
=
[[
t
[
0
]
*
num_tokens_per_tile
,
t
[
1
]
*
num_tokens_per_tile
]
for
t
in
tile_range_for_decode
]
return
cross_attention_mask
,
kv_range_for_decode
def
get_full_text_row_masked_out_mask
(
self
,
attn_metadata
:
AttentionMetadata
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
full_text_row_masked_out_mask
=
torch
.
ones
(
(
attn_metadata
.
num_prefill_tokens
,
1
),
dtype
=
torch
.
bool
)
start_pos
=
0
for
seq_len
,
encoder_seq_len
in
zip
(
attn_metadata
.
seq_lens_tensor
.
cpu
(),
attn_metadata
.
encoder_seq_lens
):
for
seq_len
,
encoder_seq_len
in
zip
(
attn_metadata
.
seq_lens
,
attn_metadata
.
encoder_seq_lens
):
if
encoder_seq_len
==
0
:
full_text_row_masked_out_mask
[
start_pos
:
start_pos
+
seq_len
]
=
False
start_pos
+=
seq_len
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
.
to
(
cross_attention_states
.
device
)
return
cross_attention_states
,
full_text_row_masked_out_mask
device
)
return
full_text_row_masked_out_mask
def
forward
(
self
,
...
...
@@ -1069,39 +1221,54 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata
.
num_decode_tokens
>
0
:
raise
ValueError
(
"Chunk prefill not supported"
)
image_inputs
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
cross_attention_states
=
None
cross_attention_mask
=
None
kv_range_for_decode
=
None
# For 1) text-only prefill and decode, 2) image-present decode.
if
image_inputs
is
None
:
cross_attention_mask
=
None
full_text_row_masked_out_mask
=
(
attn_metadata
.
encoder_seq_lens_tensor
!=
0
).
reshape
(
-
1
,
1
).
to
(
input_ids
.
device
)
cross_attention_states
=
None
skip_cross_attention
=
max
(
attn_metadata
.
encoder_seq_lens
)
==
0
# For image-present prefill.
else
:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values
=
image_inputs
[
'data'
]
aspect_ratio_ids
=
image_inputs
[
'aspect_ratio_ids'
]
aspect_ratio_mask
=
image_inputs
[
'aspect_ratio_mask'
]
cross_attention_states
=
self
.
vision_model
(
pixel_values
,
aspect_ratio_ids
,
aspect_ratio_mask
)
cross_attention_states
=
self
.
multi_modal_projector
(
cross_attention_states
)
bsz
,
_
,
_
,
_
,
image_token_dim
=
tuple
(
cross_attention_states
.
shape
)
cross_attention_states
=
cross_attention_states
.
view
(
bsz
,
-
1
,
image_token_dim
)
cross_attention_states
,
full_text_row_masked_out_mask
=
\
self
.
flat_encoder_result
(
cross_attention_states
,
attn_metadata
)
skip_cross_attention
=
False
# TODO: support multi-image by this mask
cross_attention_mask
=
None
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See input_processor_for_mllama() for more details.
num_tiles_tensor
=
kwargs
.
pop
(
"num_tiles"
)
num_tiles
=
[
t
[
0
].
tolist
()
for
t
in
num_tiles_tensor
]
num_tokens_per_tile
=
(
self
.
image_size
//
14
)
**
2
+
1
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
]
for
actual_len
,
last_group_len
in
zip
(
actual_encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens
):
assert
actual_len
>=
last_group_len
cross_attention_states
=
self
.
get_cross_attention_states
(
image_inputs
,
attn_metadata
,
actual_encoder_seq_lens
)
full_text_row_masked_out_mask
=
\
self
.
get_full_text_row_masked_out_mask
(
attn_metadata
,
input_ids
.
device
)
cross_attention_mask
,
kv_range_for_decode
=
\
self
.
get_cross_attention_mask
(
input_ids
,
attn_metadata
,
num_tiles
,
num_tokens_per_tile
,
cross_attention_states
.
dtype
)
outputs
=
self
.
language_model
(
input_ids
=
input_ids
,
positions
=
positions
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
...
...
@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def skip_attention_mask(sparse_mask: List[List[List[int]]]) -> bool:
    """Tell whether the cross-attention mask can be skipped for a batch.

    Args:
        sparse_mask: One entry per sample; each entry is a list of
            ``[start, end]`` pairs, one pair per image in that sample.

    Returns:
        ``True`` when every sample is either text-only (no pairs) or has
        exactly one pair equal to ``[0, -1]``; ``False`` otherwise.
    """
    for sample_mask in sparse_mask:
        # Skip text-only samples.
        if len(sample_mask) == 0:
            continue
        # If the sample contains more than 1 image,
        # we can't skip the mask.
        if len(sample_mask) != 1:
            return False
        # If the sample contains only 1 image,
        # but the image is not the leading one,
        # we can't skip the mask.
        only_image = sample_mask[0]
        if only_image[0] != 0 or only_image[1] != -1:
            return False
    return True
def convert_sparse_cross_attention_mask_to_dense(
    sparse_mask: List[List[List[int]]],
    num_tiles: List[List[int]],
    lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Expand per-sample ``[start, end]`` spans into a dense token/tile mask.

    Args:
        sparse_mask: For each sample, a list of ``[start, end]`` token
            spans; an ``end`` of ``-1`` stands for the sample's length.
        num_tiles: For each sample, the tile count paired with each span.
        lengths: Token length of each sample.

    Returns:
        ``(dense_mask, tile_range_for_decode)``. ``dense_mask`` is an
        ``int64`` array of shape ``(sum(lengths), total tiles)`` with 1 where
        a token falls inside a span's rows and that span's tile columns.
        ``tile_range_for_decode[i]`` is ``(start, end)`` over the tile
        columns of the spans of sample ``i`` that run to the end of the
        sample; it is ``(-1, -1)`` when the sample has no such span.
    """
    num_rows = sum(lengths)
    num_cols = sum(sum(sample_tiles) for sample_tiles in num_tiles)
    dense_mask = np.zeros(shape=(num_rows, num_cols), dtype=np.int64)

    # A list of ranges, range[i] = [start, end] means
    # if the i-th sample has N tiles in total, the tiles[start, end]
    # will be used for cross-attention decoding.
    tile_range_for_decode = []

    row_offset = 0
    col_offset = 0
    for sample_masks, sample_tiles, sample_len in zip(
            sparse_mask, num_tiles, lengths):
        first_decode_tile = -1
        decode_tile_count = 0
        for span, tile_count in zip(sample_masks, sample_tiles):
            if len(span) == 2:
                begin, stop = span
                # -1 means "until the end of the sample"; otherwise clamp.
                stop = sample_len if stop == -1 else min(stop, sample_len)
                if stop == sample_len:
                    if first_decode_tile == -1:
                        first_decode_tile = col_offset
                    decode_tile_count += tile_count
                col_end = col_offset + tile_count
                dense_mask[row_offset + begin:row_offset + stop,
                           col_offset:col_end] = 1
                col_offset = col_end
        tile_range_for_decode.append(
            (first_decode_tile, first_decode_tile + decode_tile_count))
        row_offset += sample_len

    return dense_mask, tile_range_for_decode
def convert_dense_cross_attention_mask_to_tensor(
    cross_attention_token_mask: np.ndarray,
    num_tokens_per_tile: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a 0/1 token-by-tile mask into an additive attention mask tensor.

    Each tile column is repeated ``num_tokens_per_tile`` times along dim 1.
    Entries that were 1 become 0, entries that were 0 become
    ``torch.finfo(dtype).min``. Rows in which every entry is masked are then
    multiplied by zero, so they end up all-zero instead of all-minimum.

    Returns:
        A tensor of ``dtype`` on ``device`` with the same number of rows as
        the input and ``num_tokens_per_tile`` times as many columns.
    """
    neg_inf = torch.finfo(dtype).min

    tile_mask = torch.tensor(cross_attention_token_mask,
                             dtype=dtype,
                             device=device)
    # 1 -> 0 (visible), 0 -> 1 (to be masked), after per-tile expansion.
    inverted = 1.0 - tile_mask.repeat_interleave(num_tokens_per_tile, dim=1)
    additive_mask = inverted.masked_fill(inverted.to(torch.bool), neg_inf)

    # Zero out rows that have no visible entry at all.
    row_has_visible = (additive_mask != neg_inf).any(dim=-1)
    additive_mask *= row_has_visible.type_as(additive_mask)[..., None]

    # (num_prompt_tokens, num_encoder_tokens)
    return additive_mask
vllm/model_executor/models/module_mapping.py
0 → 100644
View file @
6d2051cc
# Adapted from
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from dataclasses import dataclass, field
from typing import List, Optional, Union
@dataclass
class ModelKeys:
    """String keys naming the parts of a model, all optional.

    Every field defaults to ``None``, so the annotations are
    ``Optional[str]``.
    NOTE(review): presumably each value is a module name or path inside the
    model; confirm against the callers that populate these keys.
    """

    model_type: Optional[str] = None

    module_list: Optional[str] = None

    embedding: Optional[str] = None

    mlp: Optional[str] = None

    down_proj: Optional[str] = None

    attention: Optional[str] = None

    o_proj: Optional[str] = None

    q_proj: Optional[str] = None

    k_proj: Optional[str] = None

    v_proj: Optional[str] = None

    qkv_proj: Optional[str] = None

    qk_proj: Optional[str] = None

    qa_proj: Optional[str] = None

    qb_proj: Optional[str] = None

    kva_proj: Optional[str] = None

    kvb_proj: Optional[str] = None

    output: Optional[str] = None
@dataclass
class MultiModelKeys(ModelKeys):
    """``ModelKeys`` extended with lists of keys for multi-modal models.

    Each list field defaults to a fresh empty list.
    """

    language_model: List[str] = field(default_factory=list)

    connector: List[str] = field(default_factory=list)

    # vision tower and audio tower
    tower_model: List[str] = field(default_factory=list)

    generator: List[str] = field(default_factory=list)

    @staticmethod
    def from_string_field(
            language_model: Optional[Union[str, List[str]]] = None,
            connector: Optional[Union[str, List[str]]] = None,
            tower_model: Optional[Union[str, List[str]]] = None,
            generator: Optional[Union[str, List[str]]] = None,
            **kwargs) -> 'MultiModelKeys':
        """Create a ``MultiModelKeys`` from strings or lists of strings.

        Args:
            language_model: A single key, a list of keys, or ``None``.
            connector: Same convention as ``language_model``.
            tower_model: Same convention as ``language_model``.
            generator: Same convention as ``language_model``.
            **kwargs: Forwarded unchanged to the ``MultiModelKeys``
                constructor.

        Returns:
            A ``MultiModelKeys`` whose four list fields are always lists:
            ``None`` becomes ``[]``, a string becomes a one-element list,
            and any other iterable is copied into a new list.
        """

        def to_list(value):
            if value is None:
                return []
            # A bare string is a single key, not an iterable of characters.
            return [value] if isinstance(value, str) else list(value)

        return MultiModelKeys(language_model=to_list(language_model),
                              connector=to_list(connector),
                              tower_model=to_list(tower_model),
                              generator=to_list(generator),
                              **kwargs)
Prev
1
…
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment