Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e2000f3
Unverified
Commit
2e2000f3
authored
Aug 21, 2025
by
Paul Pak
Committed by
GitHub
Aug 21, 2025
Browse files
[Model] Add LFM2 architecture (#22845)
Signed-off-by:
Paul Pak
<
paulpak58@gmail.com
>
parent
31282401
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
960 additions
and
8 deletions
+960
-8
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+25
-8
tests/models/registry.py
tests/models/registry.py
+2
-0
tests/models/test_initialization.py
tests/models/test_initialization.py
+2
-0
vllm/config/compilation.py
vllm/config/compilation.py
+1
-0
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+24
-0
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+262
-0
vllm/model_executor/models/lfm2.py
vllm/model_executor/models/lfm2.py
+557
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/v1/attention/backends/mamba_selectors.py
vllm/v1/attention/backends/mamba_selectors.py
+4
-0
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+81
-0
No files found.
docs/models/supported_models.md
View file @
2e2000f3
...
...
@@ -373,6 +373,7 @@ th {
|
`InternLM3ForCausalLM`
| InternLM3 |
`internlm/internlm3-8b-instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`JAISLMHeadModel`
| Jais |
`inceptionai/jais-13b`
,
`inceptionai/jais-13b-chat`
,
`inceptionai/jais-30b-v3`
,
`inceptionai/jais-30b-chat-v3`
, etc. | | ✅︎ | ✅︎ |
|
`JambaForCausalLM`
| Jamba |
`ai21labs/AI21-Jamba-1.5-Large`
,
`ai21labs/AI21-Jamba-1.5-Mini`
,
`ai21labs/Jamba-v0.1`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Lfm2ForCausalLM`
| LFM2 |
`LiquidAI/LFM2-1.2B`
,
`LiquidAI/LFM2-700M`
,
`LiquidAI/LFM2-350M`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`LlamaForCausalLM`
| Llama 3.1, Llama 3, Llama 2, LLaMA, Yi |
`meta-llama/Meta-Llama-3.1-405B-Instruct`
,
`meta-llama/Meta-Llama-3.1-70B`
,
`meta-llama/Meta-Llama-3-70B-Instruct`
,
`meta-llama/Llama-2-70b-hf`
,
`01-ai/Yi-34B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`MambaForCausalLM`
| Mamba |
`state-spaces/mamba-130m-hf`
,
`state-spaces/mamba-790m-hf`
,
`state-spaces/mamba-2.8b-hf`
, etc. | | ✅︎ | ✅︎ |
|
`Mamba2ForCausalLM`
| Mamba2 |
`mistralai/Mamba-Codestral-7B-v0.1`
, etc. | | ✅︎ | ✅︎ |
...
...
tests/models/language/generation/test_hybrid.py
View file @
2e2000f3
...
...
@@ -31,6 +31,7 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM"
,
"ibm-granite/granite-4.0-tiny-preview"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"LiquidAI/LFM2-1.2B"
,
]
HF_UNSUPPORTED_MODELS
=
[
...
...
@@ -52,6 +53,7 @@ V1_SUPPORTED_MODELS = [
"hmellor/tiny-random-BambaForCausalLM"
,
"ibm-granite/granite-4.0-tiny-preview"
,
"tiiuae/Falcon-H1-0.5B-Base"
,
"LiquidAI/LFM2-1.2B"
,
]
FULL_CUDA_GRAPH_MODELS
=
[
...
...
@@ -59,6 +61,10 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct"
,
]
V0_UNSUPPORTED_MODELS
=
[
"LiquidAI/LFM2-1.2B"
,
]
# Avoid OOM
MAX_NUM_SEQS
=
4
...
...
@@ -94,9 +100,12 @@ def test_models(
else
:
hf_outputs
=
None
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_v0_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
not
in
V0_UNSUPPORTED_MODELS
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_v0_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
vllm_v0_outputs
=
None
if
model
in
V1_SUPPORTED_MODELS
:
with
monkeypatch
.
context
()
as
m
:
...
...
@@ -112,7 +121,7 @@ def test_models(
else
:
vllm_v1_outputs
=
None
if
hf_outputs
is
not
None
:
if
hf_outputs
is
not
None
and
vllm_v0_outputs
is
not
None
:
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_v0_outputs
,
...
...
@@ -122,6 +131,7 @@ def test_models(
if
model
in
V1_SUPPORTED_MODELS
:
ref_outputs
=
hf_outputs
if
hf_outputs
is
not
None
else
vllm_v0_outputs
assert
ref_outputs
is
not
None
check_logprobs_close
(
outputs_0_lst
=
ref_outputs
,
outputs_1_lst
=
vllm_v1_outputs
,
...
...
@@ -140,6 +150,9 @@ def test_batching(
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
if
model
in
V0_UNSUPPORTED_MODELS
:
pytest
.
skip
(
f
"Unsupported V0 Engine. Skipping `test_batching` on
{
model
}
."
)
try
:
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
...
...
@@ -392,9 +405,12 @@ def test_full_cuda_graph(
else
:
hf_outputs
=
None
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_v0_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
model
not
in
V0_UNSUPPORTED_MODELS
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
vllm_v0_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
else
:
vllm_v0_outputs
=
None
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
...
...
@@ -408,7 +424,7 @@ def test_full_cuda_graph(
vllm_v1_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
if
hf_outputs
is
not
None
:
if
hf_outputs
is
not
None
and
vllm_v0_outputs
is
not
None
:
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_v0_outputs
,
...
...
@@ -417,6 +433,7 @@ def test_full_cuda_graph(
)
ref_outputs
=
hf_outputs
if
hf_outputs
is
not
None
else
vllm_v0_outputs
assert
ref_outputs
is
not
None
check_logprobs_close
(
outputs_0_lst
=
ref_outputs
,
outputs_1_lst
=
vllm_v1_outputs
,
...
...
tests/models/registry.py
View file @
2e2000f3
...
...
@@ -230,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"tiny"
:
"ai21labs/Jamba-tiny-dev"
,
"random"
:
"ai21labs/Jamba-tiny-random"
,
# noqa: E501
}),
"Lfm2ForCausalLM"
:
_HfExamplesInfo
(
"LiquidAI/LFM2-1.2B"
,
min_transformers_version
=
"4.54"
),
"LlamaForCausalLM"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-1B-Instruct"
,
extras
=
{
"guard"
:
"meta-llama/Llama-Guard-3-1B"
,
# noqa: E501
"hermes"
:
"NousResearch/Hermes-3-Llama-3.1-8B"
,
# noqa: E501
...
...
tests/models/test_initialization.py
View file @
2e2000f3
...
...
@@ -95,6 +95,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
):
if
model_arch
==
"Lfm2ForCausalLM"
:
pytest
.
skip
(
"Skipping until test supports V1-only models"
)
can_initialize
(
model_arch
,
monkeypatch
,
HF_EXAMPLE_MODELS
)
...
...
vllm/config/compilation.py
View file @
2e2000f3
...
...
@@ -337,6 +337,7 @@ class CompilationConfig:
"vllm.unified_attention_with_output"
,
"vllm.mamba_mixer2"
,
"vllm.mamba_mixer"
,
"vllm.short_conv"
,
]
def
compute_hash
(
self
)
->
str
:
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
2e2000f3
...
...
@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
return
(
conv_state_dtype
,
temporal_state_dtype
)
@classmethod
def short_conv_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
    """Return the dtype(s) of the cached state of a short-conv layer.

    A short-conv layer keeps a single state tensor (the convolution
    state), so the result is a one-element tuple.

    Args:
        model_dtype: dtype of the model.
        mamba_cache_dtype: configured dtype of the mamba cache.

    Returns:
        ``(conv_state_dtype,)`` as resolved by ``get_kv_cache_torch_dtype``.
    """
    # NOTE(review): the cache dtype is passed first, the model dtype second;
    # presumably the latter is the fallback -- confirm against the helper.
    return (get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype), )
class
MambaStateShapeCalculator
:
...
...
@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
tp_world_size
),
head_dim
,
state_size
)
return
conv_state_shape
,
temporal_state_shape
@classmethod
def short_conv_state_shape(
    cls,
    tp_world_size: int,
    intermediate_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int]]:
    """Return the shape(s) of the cached state of a short-conv layer.

    Args:
        tp_world_size: tensor-parallel world size used to shard the
            convolution channels.
        intermediate_size: total number of convolution channels.
        conv_kernel: convolution kernel length; the state keeps
            ``conv_kernel - 1`` entries.
        use_v1: when ``True`` the shape is ``(conv_kernel - 1, conv_dim)``;
            otherwise the two dimensions are swapped.

    Returns:
        A one-element tuple holding the convolution state shape.
    """
    # Channels owned by one tensor-parallel rank.
    channels_per_rank = divide(intermediate_size, tp_world_size)
    history_len = conv_kernel - 1
    if use_v1:
        state_shape = (history_len, channels_per_rank)
    else:
        state_shape = (channels_per_rank, history_len)
    return (state_shape, )
@
classmethod
def
extra_groups_for_head_shards
(
cls
,
ngroups
:
int
,
tp_size
:
int
):
"""Compute the increase in group numbers to account for
...
...
vllm/model_executor/layers/mamba/short_conv.py
0 → 100644
View file @
2e2000f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
update_metadata
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionMetadata
)
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
    """Gated short causal-convolution mixer, registered as ``short_conv``.

    The input is projected into three equally sized chunks ``B``, ``C`` and
    ``x``. A causal 1-D convolution is applied to ``B * x`` (its rolling
    state lives in ``self.kv_cache``), the result is gated with ``C`` and
    finally passed through an output projection.
    """

    def __init__(self,
                 config,
                 dim: int,
                 layer_idx: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = ""):
        """Build the projections and register the layer by ``prefix``.

        Args:
            config: model config; ``conv_L_cache`` and ``conv_bias`` are read.
            dim: channel count of the convolution and of all projections.
            layer_idx: index of the owning decoder layer.
            model_config: only needed by ``get_state_dtype``.
            cache_config: only needed by ``get_state_dtype``.
            prefix: unique layer name; also the key used to look up this
                layer's metadata in the forward context.

        Raises:
            ValueError: if a layer is already registered under ``prefix``.
        """
        super().__init__()
        self.prefix = prefix
        self.model_config = model_config
        self.cache_config = cache_config
        self.config = config
        self.layer_idx = layer_idx
        self.conv_dim = dim
        self.L_cache = config.conv_L_cache
        self.bias = config.conv_bias

        self.conv = ColumnParallelLinear(
            input_size=self.L_cache,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.conv1d",
        )
        # The checkpoint stores conv1d-shaped weights, so insert a singleton
        # dim to make the linear weight match. This cannot be done in
        # `weight_loader`: `ColumnParallelLinear` already defines one and
        # `set_weight_attrs` does not allow overriding it.
        conv_weight = self.conv.weight
        conv_weight.data = conv_weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[dim] * 3,
            bias=self.bias,
            prefix=f"{prefix}.in_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.out_proj",
        )

        assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")

        # Make the layer reachable by name from the custom-op entry point.
        layer_registry = (
            get_current_vllm_config().compilation_config.static_forward_context)
        if prefix in layer_registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        layer_registry[prefix] = self

        # The outer list is for v0 PP virtual engine. Though this code path
        # only runs for v1, we have to do this to unify with the interface
        # of Attention + v0 PP.
        self.kv_cache = [(torch.tensor([]), )]

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        """Placeholder native path: does nothing and returns ``None``."""
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        """Run the layer through the ``torch.ops.vllm.short_conv`` op.

        The op receives this layer's ``prefix`` and writes into ``output``;
        ``conv_metadata`` is not forwarded.
        """
        torch.ops.vllm.short_conv(hidden_states, output, self.prefix)

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        """Compute the mixer output for a batch of decode + prefill tokens.

        When the forward context carries no attention metadata (V1 profile
        run) the gated projection is computed without the convolution and
        returned. Otherwise the result for the first ``num_actual_tokens``
        rows is written into ``output`` and nothing is returned.
        """
        ctx = get_forward_context()
        # ShortConvAttentionMetadata contains metadata necessary for the
        # short_conv triton kernels to operate in continuous batching and in
        # chunked prefill modes; they are computed at top-level model forward
        # since they stay the same and reused for all mamba layers in the same
        # iteration.
        attn_metadata: AttentionMetadata = ctx.attn_metadata
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            # Per-layer metadata is keyed by the layer name.
            attn_metadata = attn_metadata[self.prefix]
            conv_metadata = attn_metadata
            assert isinstance(attn_metadata, ShortConvAttentionMetadata)
            layer_cache = self.kv_cache[ctx.virtual_engine]
            conv_state = layer_cache[0].transpose(-1, -2)
            state_indices_tensor = attn_metadata.state_indices_tensor
            has_initial_states_p = attn_metadata.has_initial_states

        projected, _ = self.in_proj(hidden_states)
        B, C, x = projected.chunk(3, dim=-1)

        # Drop the singleton dim added in __init__ to get a 2-D kernel.
        weight = self.conv.weight
        conv_weights = weight.view(weight.size(0), weight.size(2))

        if attn_metadata is None:
            # V1 profile run
            gated = C * (B * x).contiguous()
            contextualized_states, _ = self.out_proj(gated)
            return contextualized_states

        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0
        num_actual_tokens = num_decodes + num_prefill_tokens

        # NOTE: V1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        token_sections = [num_decodes, num_prefill_tokens]
        (B_d, B_p), (C_d, C_p), (x_d, x_p) = (
            torch.split(part[:num_actual_tokens], token_sections, dim=0)
            for part in (B, C, x))
        # Split along batch dimension
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor,
            [num_decodes, num_prefills],
            dim=0,
        )
        if has_prefill:
            # Re-base the prefill start offsets so they begin after the
            # decode tokens.
            query_start_loc_p = (
                attn_metadata.query_start_loc[-num_prefills - 1:] -
                num_decodes)
        else:
            query_start_loc_p = None

        prefill_out = None
        decode_out = None

        if has_prefill:
            Bx_p = (B_p * x_p).transpose(0, 1)
            if conv_metadata.cu_seqlen is None:
                conv_metadata = update_metadata(Bx_p, query_start_loc_p,
                                                conv_metadata)
            conv_p = causal_conv1d_fn(Bx_p,
                                      conv_weights,
                                      self.conv.bias,
                                      activation=None,
                                      conv_states=conv_state,
                                      has_initial_state=has_initial_states_p,
                                      cache_indices=state_indices_tensor_p,
                                      metadata=conv_metadata,
                                      query_start_loc=query_start_loc_p)
            prefill_out = C_p * conv_p.transpose(0, 1)[:num_prefill_tokens]

        if has_decode:
            Bx_d = (B_d * x_d).contiguous()
            conv_d = causal_conv1d_update(
                Bx_d,
                conv_state,
                conv_weights,
                self.conv.bias,
                activation=None,
                conv_state_indices=state_indices_tensor_d)
            decode_out = C_d * conv_d

        # Merge prefill and decode outputs before passing to gated MLP
        # (decode rows first, matching the input layout).
        conv_output_list = [
            piece for piece in (decode_out, prefill_out) if piece is not None
        ]
        hidden_states = torch.vstack(conv_output_list)

        # Final linear projection
        final_states, _ = self.out_proj(hidden_states)
        output[:num_actual_tokens] = final_states

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return the dtype(s) of this layer's cached state.

        Requires both ``model_config`` and ``cache_config`` to have been
        supplied at construction time.
        """
        assert self.model_config is not None
        assert self.cache_config is not None
        model_dtype = self.model_config.dtype
        cache_dtype = self.cache_config.mamba_cache_dtype
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            model_dtype,
            cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...]]:
        """Return the shape(s) of this layer's cached state."""
        tp_size = get_tensor_model_parallel_world_size()
        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=tp_size,
            intermediate_size=self.conv_dim,
            conv_kernel=self.L_cache,
        )

    @property
    def mamba_type(self) -> str:
        """Identifier of this state-carrying layer kind."""
        return "short_conv"
def short_conv(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op body: dispatch to the layer registered as ``layer_name``.

    The layer is looked up in the current forward context and its
    ``forward_cuda`` is invoked with ``conv_metadata=None``; any value it
    returns is discarded, so results are only visible through ``output``.
    """
    forward_context: ForwardContext = get_forward_context()
    layer = forward_context.no_compile_layers[layer_name]
    layer.forward_cuda(
        hidden_states=hidden_states,
        output=output,
        conv_metadata=None,
    )
def short_conv_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``short_conv``: performs no work.

    It mirrors the real op's signature, touches none of its arguments and
    returns ``None``.
    """
    return None
# Register `short_conv` as a custom op at import time. `short_conv` is the
# real implementation and `short_conv_fake` the fake one; `output` is declared
# as the only mutated argument because the op returns nothing and delivers its
# result by writing into that tensor.
# NOTE(review): `ShortConv.forward` calls `torch.ops.vllm.short_conv`, so the
# op is presumably registered under the `vllm` namespace -- confirm in
# `direct_register_custom_op`.
direct_register_custom_op(
    op_name="short_conv",
    op_func=short_conv,
    mutates_args=["output"],
    fake_impl=short_conv_fake,
    dispatch_key=current_platform.dispatch_key,
)
vllm/model_executor/models/lfm2.py
0 → 100644
View file @
2e2000f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
Any
,
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
Lfm2Config
from
vllm
import
envs
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.short_conv
import
ShortConv
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
)
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class Lfm2MLP(nn.Module):
    """Gated feed-forward block of LFM2 (SiLU-gated MLP).

    ``w1`` is the merged gate/up projection (checkpoint tensors ``w1`` and
    ``w3``) and ``w2`` is the down projection.
    """

    def __init__(
        self,
        dim: int,
        ff_dim: int,
        multiple_of: int,
        auto_adjust_ff_dim: bool,
        ffn_dim_multiplier: Optional[float],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the projections.

        Args:
            dim: input and output width of the block.
            ff_dim: nominal hidden width of the block.
            multiple_of: when auto-adjusting, the hidden width is rounded up
                to a multiple of this value.
            auto_adjust_ff_dim: if true, the hidden width becomes
                ``2 * ff_dim / 3``, optionally scaled by
                ``ffn_dim_multiplier``, then rounded up to ``multiple_of``.
            ffn_dim_multiplier: optional extra scale used while
                auto-adjusting.
            quant_config: quantization config forwarded to both projections.
            prefix: dotted module path of this block.
        """
        super().__init__()
        if auto_adjust_ff_dim:
            ff_dim = int(2 * ff_dim / 3)
            # custom dim factor multiplier
            if ffn_dim_multiplier is not None:
                ff_dim = int(ffn_dim_multiplier * ff_dim)
            # Round up to the next multiple of `multiple_of`.
            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

        # The prefixes must name the real attribute paths (`w1` / `w2`):
        # that is how the weight loader and the packed-module mapping refer
        # to these layers, so prefix-based per-layer lookups (e.g. by a
        # quantization config) resolve to the right module.
        self.w1 = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[ff_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w1",
        )
        self.w2 = RowParallelLinear(
            input_size=ff_dim,
            output_size=dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the gated activation and down projection."""
        gate_up, _ = self.w1(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x
class Lfm2Attention(nn.Module):
    """Self-attention of LFM2 with per-head RMSNorm on Q and K before RoPE."""

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create projections, rotary embedding, attention op and Q/K norms.

        Args:
            config: model config; only ``norm_eps`` is read here.
            layer_idx: index of the owning decoder layer.
            hidden_size: model width.
            num_heads: total number of query heads (before TP sharding).
            num_kv_heads: total number of key/value heads (before TP
                sharding).
            rope_theta: rotary embedding base.
            rope_scaling: optional rotary scaling config.
            max_position_embeddings: maximum position for the rotary table.
            cache_config: cache config forwarded to the attention op.
            quant_config: quantization config for the linear layers.
            prefix: dotted module path of this block.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Query heads are always sharded across the TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        norm_eps = config.norm_eps
        self.q_layernorm = RMSNorm(self.head_dim, eps=norm_eps)
        self.k_layernorm = RMSNorm(self.head_dim, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` (2-D: tokens x hidden) and project."""
        n_tokens, _ = hidden_states.shape
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Expose the head dimension so the norms act on each head separately.
        q_heads = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
        k_heads = k.view(n_tokens, self.num_kv_heads,
                         self.head_dim).contiguous()
        q_heads = self.q_layernorm(q_heads)
        k_heads = self.k_layernorm(k_heads)
        q_heads, k_heads = self.rotary_emb(positions, q_heads, k_heads)

        # Flatten the heads again for the attention op.
        q = q_heads.view(n_tokens, self.num_heads * self.head_dim)
        k = k_heads.view(n_tokens, self.num_kv_heads * self.head_dim)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output
class Lfm2AttentionDecoderLayer(nn.Module):
    """LFM2 decoder layer whose token mixer is full self-attention."""

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and the two RMSNorms.

        ``model_config`` is accepted for signature parity with the
        short-conv decoder layer and is not used here.
        """
        super().__init__()
        self.prefix = prefix
        self.config = config
        self.layer_idx = layer_idx

        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_position = getattr(config,
                                        "original_max_position_embeddings",
                                        None)
        if rope_scaling is not None and original_max_position:
            # Note: this writes into the dict obtained from the config.
            rope_scaling["original_max_position_embeddings"] = (
                original_max_position)

        self.self_attn = Lfm2Attention(
            config=config,
            layer_idx=layer_idx,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            ff_dim=config.block_ff_dim,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run norm -> attention -> norm -> MLP.

        Returns:
            ``(mlp_output, residual)``; the residual is carried separately
            and is not added to the MLP output here.
        """
        if residual is not None:
            hidden_states, residual = self.operator_norm(
                hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)

        attn_out = self.self_attn(positions=positions,
                                  hidden_states=hidden_states)
        normed, residual = self.ffn_norm(attn_out, residual)
        return self.feed_forward(normed), residual
class Lfm2ShortConvDecoderLayer(nn.Module):
    """LFM2 decoder layer whose token mixer is a ``ShortConv`` block."""

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the short-conv block, the MLP and the two RMSNorms."""
        super().__init__()
        self.layer_idx = layer_idx

        self.conv = ShortConv(
            config=config,
            dim=config.conv_dim,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.conv",
        )
        self.feed_forward = Lfm2MLP(
            dim=config.block_dim,
            ff_dim=config.block_ff_dim,
            multiple_of=config.block_multiple_of,
            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        hidden_size = config.hidden_size
        norm_eps = config.norm_eps
        self.operator_norm = RMSNorm(hidden_size, eps=norm_eps)
        self.ffn_norm = RMSNorm(hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Run norm -> short conv -> norm -> MLP.

        Extra keyword arguments (e.g. ``positions``) are accepted and
        ignored so both decoder layer kinds can be called the same way.

        Returns:
            ``(mlp_output, residual)``.
        """
        if residual is not None:
            hidden_states, residual = self.operator_norm(
                hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)

        # The conv block returns nothing; it fills this buffer in place.
        conv_out = torch.empty_like(hidden_states)
        self.conv(
            hidden_states,
            conv_out,
            conv_metadata=None,
        )
        normed, residual = self.ffn_norm(conv_out, residual)
        mlp_out = self.feed_forward(normed)
        return mlp_out, residual
@support_torch_compile
class Lfm2Model(nn.Module):
    """LFM2 backbone: token embedding, hybrid decoder stack and final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm.

        Each layer is an attention layer when the matching entry of
        ``config.layer_types`` equals ``"full_attention"`` and a short-conv
        layer otherwise.
        """
        super().__init__()
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config = model_config.hf_config

        self.config = config

        # Extra embedding rows reserved for LoRA-added tokens, if any.
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size)

        def get_layer(prefix: str):
            # The layer index is recovered from the dotted layer prefix.
            layer_idx = extract_layer_index(prefix)
            if self.config.layer_types[layer_idx] == "full_attention":
                layer_class = Lfm2AttentionDecoderLayer
            else:
                layer_class = Lfm2ShortConvDecoderLayer
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Only the last pipeline rank applies the final norm.
        if not get_pp_group().is_last_rank:
            self.embedding_norm = PPMissingLayer()
        else:
            self.embedding_norm = RMSNorm(config.hidden_size,
                                          eps=config.norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this pipeline rank's slice of the decoder stack.

        The first rank starts from ``inputs_embeds`` (or the embedded
        ``input_ids``); later ranks resume from ``intermediate_tensors``.
        Ranks other than the last return an ``IntermediateTensors`` holding
        ``hidden_states`` and ``residual``; the last rank returns the
        normalized hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            residual = None
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.embedding_norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint tensors named ``q_proj``/``k_proj``/``v_proj`` are loaded
        as shards of ``qkv_proj`` and ``w1``/``w3`` as shards of ``w1``;
        everything else is loaded under its own name. Parameters that
        ``is_pp_missing_parameter`` reports as missing are skipped.

        Returns:
            The (possibly remapped) names of the parameters that were loaded.
        """
        # (fused parameter name, checkpoint name, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".w1", ".w1", 0),
            (".w1", ".w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for fused_name, ckpt_name, shard_id in stacked_params_mapping:
                if ckpt_name in name:
                    name = name.replace(ckpt_name, fused_name)
                    if not is_pp_missing_parameter(name, self):
                        param = params_dict[name]
                        param.weight_loader(param, loaded_weight, shard_id)
                        break
            else:
                # No fused shard was loaded for this tensor.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                      IsHybrid, SupportsQuant):
    """LFM2 causal language model: an ``Lfm2Model`` backbone plus an LM head.

    The LM head and logits processor live on the last pipeline-parallel
    rank; other ranks hold a ``PPMissingLayer`` placeholder for the head.
    """

    # Checkpoint projections that are stored fused in this implementation.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "w1": [
            "w1",
            "w3",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, ...]:
        """Return the dtype(s) of the short-conv state cache.

        Delegates to ``MambaStateDtypeCalculator.short_conv_state_dtype``
        with the model dtype and the configured mamba cache dtype.
        """
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int]]:
        """ Calculate shapes for LFM2's convolutional cache.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.conv_dim,
            conv_kernel=hf_config.conv_L_cache,
            use_v1=use_v1,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Full vLLM configuration.
            prefix: Module-name prefix used for the submodules.

        Raises:
            AssertionError: If prefix caching is enabled or vLLM V1 is not
                in use.
        """
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert (not cache_config.enable_prefix_caching
                ), "Lfm2 currently does not support prefix caching"
        assert envs.VLLM_USE_V1, (
            "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config

        self.model = Lfm2Model(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Computed on every rank: the logits processor below reads it
        # unconditionally, so setting it only on the last rank would raise
        # AttributeError on all other pipeline stages.
        self.unpadded_vocab_size = self.config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            # NOTE(review): the head is tied to the input embedding
            # unconditionally, while load_weights() skips "lm_head." only
            # when config.tie_word_embeddings is set -- confirm LFM2
            # checkpoints always tie embeddings.
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the backbone and return its output unchanged.

        Extra ``kwargs`` are accepted but not forwarded to the backbone.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Tensors prefixed ``lm_head.`` are skipped when the config ties the
        word embeddings.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            as the set of loaded parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
2e2000f3
...
...
@@ -93,6 +93,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"Lfm2ForCausalLM"
:
(
"lfm2"
,
"Lfm2ForCausalLM"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Llama4ForCausalLM"
:
(
"llama4"
,
"Llama4ForCausalLM"
),
# noqa: E501
# For decapoda-research/llama-*
...
...
vllm/v1/attention/backends/mamba_selectors.py
View file @
2e2000f3
...
...
@@ -4,6 +4,8 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
def
get_mamba_attn_backend
(
mamba_type
:
str
)
->
type
[
AttentionBackend
]:
...
...
@@ -13,6 +15,8 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
return
Mamba2AttentionBackend
if
mamba_type
==
"linear_attention"
:
return
LinearAttentionBackend
if
mamba_type
==
"short_conv"
:
return
ShortConvAttentionBackend
raise
NotImplementedError
(
f
"Mamba Attention type
{
mamba_type
}
is not "
"supported yet."
)
vllm/v1/attention/backends/short_conv_attn.py
0 → 100644
View file @
2e2000f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
ClassVar
,
Optional
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
class ShortConvAttentionBackend(AttentionBackend):
    """Attention backend selected for the ``"short_conv"`` mamba type.

    The only hook defined here is the metadata builder class.
    """

    @staticmethod
    def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        builder_cls = ShortConvAttentionMetadataBuilder
        return builder_cls
@dataclass
class ShortConvAttentionMetadata:
    """Per-batch metadata produced by ``ShortConvAttentionMetadataBuilder``.

    Holds the decode/prefill split counts together with the tensors the
    builder copies from the common attention metadata.
    """

    # Request and token counts for the prefill and decode parts of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor
    # One boolean per prefill request (True when it already has computed
    # tokens); the builder passes None when the batch has no prefills.
    has_initial_states: Optional[torch.Tensor]
    state_indices_tensor: torch.Tensor  # shape: [batch,]

    # For causal_conv1d
    nums_dict: Optional[dict] = None
    cu_seqlen: Optional[int] = None
    # Annotated with the torch.Tensor class; torch.tensor is the factory
    # function and is not a type.
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None
class ShortConvAttentionMetadataBuilder(
        AttentionMetadataBuilder[ShortConvAttentionMetadata]):
    """Builds ``ShortConvAttentionMetadata`` from common attention metadata."""

    # Threshold passed to split_decodes_and_prefills() as decode_threshold.
    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Store the KV-cache spec, which must be a ``MambaSpec``.

        ``layer_names``, ``vllm_config`` and ``device`` are accepted but
        not used here.
        """
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> ShortConvAttentionMetadata:
        """Assemble the metadata for one batch.

        Args:
            common_prefix_len: Unused by this builder.
            common_attn_metadata: Batch-level metadata to derive from.
            fast_build: Unused by this builder.

        Returns:
            Metadata whose ``has_initial_states`` is ``None`` when the
            batch contains no prefill requests.
        """
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc

        # First block-table column is used as the state slot index.
        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        # Use the class-level threshold instead of a second hard-coded 1 so
        # the split cannot drift from reorder_batch_threshold.
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold))

        has_initial_states = None
        if num_prefills > 0:
            #[batch,]
            # Takes the trailing num_prefills requests -- assumes prefills
            # are ordered after decodes in the batch; TODO confirm.
            has_initial_states_cpu = (
                common_attn_metadata.
                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

        attn_metadata = ShortConvAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment