Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bbf81f9a
Unverified
Commit
bbf81f9a
authored
Mar 01, 2026
by
Asaf Gardin
Committed by
GitHub
Mar 01, 2026
Browse files
[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)
Signed-off-by:
Josephasafg
<
ajgard7@gmail.com
>
parent
da543d1a
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
251 additions
and
146 deletions
+251
-146
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+3
-1
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+68
-35
csrc/ops.h
csrc/ops.h
+3
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+3
-1
tests/kernels/mamba/test_mamba_ssm.py
tests/kernels/mamba/test_mamba_ssm.py
+4
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-0
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+4
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+4
-0
vllm/v1/attention/backends/mamba1_attn.py
vllm/v1/attention/backends/mamba1_attn.py
+31
-2
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+6
-106
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+121
-0
No files found.
csrc/mamba/mamba_ssm/selective_scan.h
View file @
bbf81f9a
...
...
@@ -17,7 +17,7 @@
struct
SSMParamsBase
{
using
index_t
=
size_t
;
int
batch
,
dim
,
seqlen
,
dstate
,
n_groups
,
n_chunks
;
int
batch
,
dim
,
seqlen
,
dstate
,
n_groups
;
int
dim_ngroups_ratio
;
bool
is_variable_B
;
bool
is_variable_C
;
...
...
@@ -72,6 +72,8 @@ struct SSMParamsBase {
void
*
__restrict__
block_idx_first_scheduled_token_ptr
;
// (batch,) - first block to write
void
*
__restrict__
block_idx_last_scheduled_token_ptr
;
// (batch,) - last block to write
void
*
__restrict__
initial_state_idx_ptr
;
// (batch,) - index of the initial state to use
void
*
__restrict__
cu_chunk_seqlen_ptr
;
// (nchunks+1,) - cumulative chunk token offsets
void
*
__restrict__
last_chunk_indices_ptr
;
// (batch,) - index of last chunk per sequence
};
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
bbf81f9a
...
...
@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr
bool
kIsVariableC
=
Ktraits
::
kIsVariableC
;
constexpr
bool
kHasZ
=
Ktraits
::
kHasZ
;
constexpr
bool
kVarlen
=
Ktraits
::
kVarlen
;
constexpr
int
kNThreads
=
Ktraits
::
kNThreads
;
constexpr
int
kNItems
=
Ktraits
::
kNItems
;
constexpr
int
kNRows
=
Ktraits
::
kNRows
;
constexpr
bool
kDirectIO
=
Ktraits
::
kDirectIO
;
...
...
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
// }
constexpr
int
kChunkSize
=
kNThreads
*
kNItems
;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const
int
iteration_chunk_size
=
params
.
cache_enabled
?
params
.
block_size
:
2048
;
const
int
n_chunks
=
(
seqlen
+
iteration_chunk_size
-
1
)
/
iteration_chunk_size
;
const
int
block_size
=
params
.
cache_enabled
?
params
.
block_size
:
2048
;
const
int
*
batch_cache_indices
=
cache_indices
!=
nullptr
?
cache_indices
+
batch_id
*
params
.
cache_indices_stride
:
nullptr
;
...
...
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
reinterpret_cast
<
const
int
*>
(
params
.
block_idx_last_scheduled_token_ptr
)
:
nullptr
;
const
int
*
initial_state_idx
=
params
.
initial_state_idx_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
initial_state_idx_ptr
)
:
nullptr
;
const
int
*
cu_chunk_seqlen
=
params
.
cu_chunk_seqlen_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
cu_chunk_seqlen_ptr
)
:
nullptr
;
const
int
*
last_chunk_indices
=
params
.
last_chunk_indices_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
last_chunk_indices_ptr
)
:
nullptr
;
const
size_t
load_cache_slot
=
params
.
cache_enabled
&&
batch_cache_indices
!=
nullptr
?
batch_cache_indices
[
initial_state_idx
[
batch_id
]]
:
cache_index
;
const
int
block_idx_first
=
(
params
.
cache_enabled
&&
block_idx_first_scheduled
!=
nullptr
)
?
block_idx_first_scheduled
[
batch_id
]
:
0
;
// Determine chunk boundaries from pre-computed metadata (APC mode)
// or fall back to simple block_size chunking.
int
first_chunk_idx
,
n_chunks
;
int
current_position
;
if
(
cu_chunk_seqlen
!=
nullptr
&&
last_chunk_indices
!=
nullptr
)
{
const
int
last_chunk_idx
=
last_chunk_indices
[
batch_id
];
first_chunk_idx
=
(
batch_id
==
0
)
?
0
:
last_chunk_indices
[
batch_id
-
1
]
+
1
;
n_chunks
=
last_chunk_idx
-
first_chunk_idx
+
1
;
// Derive current_position: if the first chunk is partial (fills remainder
// of a started block), offset into the block accordingly.
const
int
first_chunk_tokens
=
cu_chunk_seqlen
[
first_chunk_idx
+
1
]
-
cu_chunk_seqlen
[
first_chunk_idx
];
const
int
chunk_start_offset
=
(
n_chunks
>
1
&&
first_chunk_tokens
<
block_size
)
?
(
block_size
-
first_chunk_tokens
)
:
0
;
current_position
=
block_idx_first
*
block_size
+
chunk_start_offset
;
}
else
{
first_chunk_idx
=
0
;
n_chunks
=
(
seqlen
+
block_size
-
1
)
/
block_size
;
current_position
=
0
;
}
int
tokens_processed
=
0
;
for
(
int
chunk
=
0
;
chunk
<
n_chunks
;
++
chunk
)
{
const
int
chunk_tokens
=
(
cu_chunk_seqlen
!=
nullptr
)
?
cu_chunk_seqlen
[
first_chunk_idx
+
chunk
+
1
]
-
cu_chunk_seqlen
[
first_chunk_idx
+
chunk
]
:
min
(
block_size
,
seqlen
-
tokens_processed
);
if
(
chunk_tokens
<=
0
)
break
;
input_t
u_vals
[
kNRows
][
kNItems
],
delta_vals_load
[
kNRows
][
kNItems
];
__syncthreads
();
...
...
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if
constexpr
(
!
kDirectIO
)
{
if
(
r
>
0
)
{
__syncthreads
();
}
}
load_input
<
Ktraits
>
(
u
+
r
*
params
.
u_d_stride
,
u_vals
[
r
],
smem_load
,
seqlen
-
chunk
*
kChunkSize
);
load_input
<
Ktraits
>
(
u
+
r
*
params
.
u_d_stride
,
u_vals
[
r
],
smem_load
,
chunk_tokens
);
if
constexpr
(
!
kDirectIO
)
{
__syncthreads
();
}
load_input
<
Ktraits
>
(
delta
+
r
*
params
.
delta_d_stride
,
delta_vals_load
[
r
],
smem_load
,
seqlen
-
chunk
*
kChunkSize
);
load_input
<
Ktraits
>
(
delta
+
r
*
params
.
delta_d_stride
,
delta_vals_load
[
r
],
smem_load
,
chunk_tokens
);
}
u
+=
kC
hunk
Size
;
delta
+=
kC
hunk
Size
;
u
+=
c
hunk
_tokens
;
delta
+=
c
hunk
_tokens
;
float
delta_vals
[
kNRows
][
kNItems
],
delta_u_vals
[
kNRows
][
kNItems
],
out_vals
[
kNRows
][
kNItems
];
#pragma unroll
...
...
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
weight_t
B_vals
[
kNItems
],
C_vals
[
kNItems
];
if
constexpr
(
kIsVariableB
)
{
load_weight
<
Ktraits
>
(
Bvar
+
state_idx
*
params
.
B_dstate_stride
,
B_vals
,
smem_load_weight
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
)
);
smem_load_weight
,
chunk_tokens
);
if
constexpr
(
!
kIsVariableC
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
...
...
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if
constexpr
(
kIsVariableC
)
{
auto
&
smem_load_weight_C
=
!
kIsVariableB
?
smem_load_weight
:
smem_load_weight1
;
load_weight
<
Ktraits
>
(
Cvar
+
state_idx
*
params
.
C_dstate_stride
,
C_vals
,
smem_load_weight_C
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
)
);
smem_load_weight_C
,
chunk_tokens
);
if
constexpr
(
!
kIsVariableB
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
...
...
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
thread_data
[
i
]
=
make_float2
(
exp2f
(
delta_vals
[
r
][
i
]
*
A_val
[
r
]),
!
kIsVariableB
?
delta_u_vals
[
r
][
i
]
:
B_vals
[
i
]
*
delta_u_vals
[
r
][
i
]);
if
(
seqlen
%
(
kNItems
*
kNThreads
)
!=
0
)
{
// So that the last state is correct
if
(
threadIdx
.
x
*
kNItems
+
i
>=
seqlen
-
chunk
*
kChunkSize
)
{
thread_data
[
i
]
=
make_float2
(
1.
f
,
0.
f
);
}
if
(
threadIdx
.
x
*
kNItems
+
i
>=
chunk_tokens
)
{
thread_data
[
i
]
=
make_float2
(
1.
f
,
0.
f
);
}
}
// Initialize running total
...
...
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if
(
threadIdx
.
x
==
0
)
{
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
]
=
prefix_op
.
running_prefix
;
// Store state at the end of each chunk when cache is enabled
// Store state at the end of each
aligned
chunk when cache is enabled
if
(
params
.
cache_enabled
&&
batch_cache_indices
!=
nullptr
)
{
size_t
cache_slot
;
if
(
chunk
==
n_chunks
-
1
)
{
cache_slot
=
batch_cache_indices
[
block_idx_last_scheduled
[
batch_id
]];
}
else
{
cache_slot
=
batch_cache_indices
[
block_idx_first_scheduled
[
batch_id
]
+
chunk
];
const
int
block_idx_completed
=
(
current_position
+
chunk_tokens
-
1
)
/
block_size
;
cache_slot
=
batch_cache_indices
[
block_idx_completed
];
}
size_t
state_offset
=
cache_slot
*
params
.
ssm_states_batch_stride
+
...
...
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
sequence_start_index
*
params
.
out_batch_stride
+
dim_id
*
kNRows
*
params
.
out_d_stride
+
chunk
*
kChunkSize
;
+
dim_id
*
kNRows
*
params
.
out_d_stride
+
tokens_processed
;
__syncthreads
();
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
if
constexpr
(
!
kDirectIO
)
{
if
(
r
>
0
)
{
__syncthreads
();
}
}
store_output
<
Ktraits
>
(
out
+
r
*
params
.
out_d_stride
,
out_vals
[
r
],
smem_store
,
seqlen
-
chunk
*
kChunkSize
);
store_output
<
Ktraits
>
(
out
+
r
*
params
.
out_d_stride
,
out_vals
[
r
],
smem_store
,
chunk_tokens
);
}
if
constexpr
(
kHasZ
)
{
input_t
*
z
=
reinterpret_cast
<
input_t
*>
(
params
.
z_ptr
)
+
sequence_start_index
*
params
.
z_batch_stride
+
dim_id
*
kNRows
*
params
.
z_d_stride
+
chunk
*
kChunkSize
;
+
dim_id
*
kNRows
*
params
.
z_d_stride
+
tokens_processed
;
input_t
*
out_z
=
reinterpret_cast
<
input_t
*>
(
params
.
out_z_ptr
)
+
sequence_start_index
*
params
.
out_z_batch_stride
+
dim_id
*
kNRows
*
params
.
out_z_d_stride
+
chunk
*
kChunkSize
;
+
dim_id
*
kNRows
*
params
.
out_z_d_stride
+
tokens_processed
;
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
input_t
z_vals
[
kNItems
];
__syncthreads
();
load_input
<
Ktraits
>
(
z
+
r
*
params
.
z_d_stride
,
z_vals
,
smem_load
,
seqlen
-
chunk
*
kChunkSize
);
load_input
<
Ktraits
>
(
z
+
r
*
params
.
z_d_stride
,
z_vals
,
smem_load
,
chunk_tokens
);
#pragma unroll
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
float
z_val
=
z_vals
[
i
];
out_vals
[
r
][
i
]
*=
z_val
/
(
1
+
expf
(
-
z_val
));
}
__syncthreads
();
store_output
<
Ktraits
>
(
out_z
+
r
*
params
.
out_z_d_stride
,
out_vals
[
r
],
smem_store
,
seqlen
-
chunk
*
kChunkSize
);
store_output
<
Ktraits
>
(
out_z
+
r
*
params
.
out_z_d_stride
,
out_vals
[
r
],
smem_store
,
chunk_tokens
);
}
}
Bvar
+=
kChunkSize
*
1
;
Cvar
+=
kChunkSize
*
1
;
Bvar
+=
chunk_tokens
;
Cvar
+=
chunk_tokens
;
tokens_processed
+=
chunk_tokens
;
current_position
+=
chunk_tokens
;
}
}
...
...
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
)
{
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
,
const
std
::
optional
<
torch
::
Tensor
>
&
cu_chunk_seqlen
,
const
std
::
optional
<
torch
::
Tensor
>
&
last_chunk_indices
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
...
...
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
block_idx_first_scheduled_token_ptr
=
block_idx_first_scheduled_token
.
has_value
()
?
block_idx_first_scheduled_token
.
value
().
data_ptr
()
:
nullptr
;
params
.
block_idx_last_scheduled_token_ptr
=
block_idx_last_scheduled_token
.
has_value
()
?
block_idx_last_scheduled_token
.
value
().
data_ptr
()
:
nullptr
;
params
.
initial_state_idx_ptr
=
initial_state_idx
.
has_value
()
?
initial_state_idx
.
value
().
data_ptr
()
:
nullptr
;
params
.
cu_chunk_seqlen_ptr
=
cu_chunk_seqlen
.
has_value
()
?
cu_chunk_seqlen
.
value
().
data_ptr
()
:
nullptr
;
params
.
last_chunk_indices_ptr
=
last_chunk_indices
.
has_value
()
?
last_chunk_indices
.
value
().
data_ptr
()
:
nullptr
;
// All stride are in elements, not bytes.
params
.
A_d_stride
=
A
.
stride
(
0
);
...
...
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
)
{
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
,
const
std
::
optional
<
torch
::
Tensor
>
&
cu_chunk_seqlen
,
const
std
::
optional
<
torch
::
Tensor
>
&
last_chunk_indices
)
{
auto
input_type
=
u
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
...
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
block_size
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
initial_state_idx
initial_state_idx
,
cu_chunk_seqlen
,
last_chunk_indices
);
...
...
csrc/ops.h
View file @
bbf81f9a
...
...
@@ -371,7 +371,9 @@ void selective_scan_fwd(
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>&
initial_state_idx
);
const
std
::
optional
<
torch
::
Tensor
>&
initial_state_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cu_chunk_seqlen
,
const
std
::
optional
<
torch
::
Tensor
>&
last_chunk_indices
);
torch
::
Tensor
dynamic_4bit_int_moe_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
...
...
csrc/torch_bindings.cpp
View file @
bbf81f9a
...
...
@@ -640,7 +640,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx) -> ()"
);
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()"
);
ops
.
impl
(
"selective_scan_fwd"
,
torch
::
kCUDA
,
&
selective_scan_fwd
);
// Hadamard transforms
...
...
tests/kernels/mamba/test_mamba_ssm.py
View file @
bbf81f9a
...
...
@@ -183,6 +183,8 @@ def selective_scan_opcheck_fn(
block_idx_first_scheduled_token
=
None
,
block_idx_last_scheduled_token
=
None
,
initial_state_idx
=
None
,
cu_chunk_seqlen
=
None
,
last_chunk_indices
=
None
,
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
...
...
@@ -231,6 +233,8 @@ def selective_scan_opcheck_fn(
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
initial_state_idx
,
cu_chunk_seqlen
,
last_chunk_indices
,
),
test_utils
=
[
"test_schema"
,
"test_faketensor"
],
)
...
...
vllm/_custom_ops.py
View file @
bbf81f9a
...
...
@@ -2021,6 +2021,8 @@ def selective_scan_fwd(
block_idx_first_scheduled_token
:
torch
.
Tensor
|
None
=
None
,
block_idx_last_scheduled_token
:
torch
.
Tensor
|
None
=
None
,
initial_state_idx
:
torch
.
Tensor
|
None
=
None
,
cu_chunk_seqlen
:
torch
.
Tensor
|
None
=
None
,
last_chunk_indices
:
torch
.
Tensor
|
None
=
None
,
):
torch
.
ops
.
_C
.
selective_scan_fwd
(
u
,
...
...
@@ -2041,6 +2043,8 @@ def selective_scan_fwd(
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
initial_state_idx
,
cu_chunk_seqlen
,
last_chunk_indices
,
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
bbf81f9a
...
...
@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
cu_chunk_seqlen_p
=
attn_metadata
.
cu_chunk_seqlen_p
last_chunk_indices_p
=
attn_metadata
.
last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
...
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
block_idx_first_scheduled_token
=
block_idx_first_scheduled_token_p
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token_p
,
initial_state_idx
=
block_idx_last_computed_token_p
,
cu_chunk_seqlen
=
cu_chunk_seqlen_p
,
last_chunk_indices
=
last_chunk_indices_p
,
)
ssm_outputs
.
append
(
scan_out_p
)
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
bbf81f9a
...
...
@@ -497,6 +497,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token
=
None
,
block_idx_last_scheduled_token
=
None
,
initial_state_idx
=
None
,
cu_chunk_seqlen
=
None
,
last_chunk_indices
=
None
,
)
->
torch
.
Tensor
:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
...
...
@@ -588,6 +590,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
initial_state_idx
,
cu_chunk_seqlen
,
last_chunk_indices
,
)
if
z
is
None
:
...
...
vllm/v1/attention/backends/mamba1_attn.py
View file @
bbf81f9a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
from
typing
import
Any
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backend
import
AttentionBackend
,
CommonAttentionMetadata
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadata
,
BaseMambaAttentionMetadataBuilder
,
...
...
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder
[
Mamba1AttentionMetadata
]
):
metadata_cls
=
Mamba1AttentionMetadata
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
    **kwargs: Any,
) -> Mamba1AttentionMetadata:
    """
    Build the Mamba1 attention metadata for the current batch.

    The shared metadata is always computed via ``_compute_common_metadata``.
    Chunk-boundary tensors are attached on top of it only when the batch
    contains at least one prefill and ``mamba_cache_mode`` is ``"all"``;
    otherwise the shared metadata is returned as is.

    ``common_prefix_len``, ``fast_build`` and ``kwargs`` are accepted but
    not read by this method.
    """
    metadata = self._compute_common_metadata(common_attn_metadata)

    # Guard clause: nothing to add unless there are prefills *and* the
    # cache mode is "all". The prefill check runs first, so the cache
    # config is not consulted for batches without prefills.
    if (
        not metadata.num_prefills > 0
        or self.vllm_config.cache_config.mamba_cache_mode != "all"
    ):
        return metadata

    # Chunks are sized by the KV-cache spec's block size. The middle element
    # of the returned triple (the per-chunk sequence index) is not needed.
    chunk_offsets, _unused_seq_idx, last_chunks = self._build_chunk_metadata_tensors(
        self.kv_cache_spec.block_size,
        metadata,
        common_attn_metadata,
    )

    return replace(
        metadata,
        cu_chunk_seqlen_p=chunk_offsets,
        last_chunk_indices_p=last_chunks,
    )
vllm/v1/attention/backends/mamba2_attn.py
View file @
bbf81f9a
...
...
@@ -7,7 +7,6 @@ from typing import Any
import
torch
from
vllm.config
import
VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
CommonAttentionMetadata
,
...
...
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
# Chunk-related metadata (only for prefill)
seq_idx_p
:
torch
.
Tensor
|
None
=
None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p
:
torch
.
Tensor
|
None
=
None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p
:
torch
.
Tensor
|
None
=
None
class
Mamba2AttentionMetadataBuilder
(
...
...
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
)
self
.
chunk_size
:
int
=
chunk_size
def _compute_chunk_metadata(
    self,
    num_prefills: int,
    num_computed_tokens_p_cpu: torch.Tensor,
    query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
    """
    Partition the new tokens of every prefill request into Mamba2 chunks.

    Invariants maintained by the construction:

    1. A chunk never mixes tokens of two different sequences.
    2. Within a sequence, chunk boundaries fall on multiples of
       ``self.chunk_size`` counted from the start of the sequence
       (including tokens computed earlier), so the state can be retrieved
       every ``chunk_size`` tokens. When the already-computed token count
       is not a multiple of ``chunk_size`` (chunked prefill), a shorter
       leading chunk is emitted first to reach the next boundary.

    Returns a triple ``(cu_chunk_seqlen, seq_idx, last_chunk_indices)``:

    * ``cu_chunk_seqlen`` - start offset of each chunk plus one final entry
      holding the total token count (``nchunks + 1`` values).
    * ``seq_idx`` - index of the request each chunk belongs to.
    * ``last_chunk_indices`` - for every request, the index of its last
      chunk.
    """
    # TODO (tdoublep): This code could probably be optimized.
    size = self.chunk_size

    chunk_starts: list[int] = []
    chunk_owner: list[int] = []
    final_chunk_of_req: list[int] = []
    cursor = 0  # running offset into the packed prefill tokens

    for req in range(num_prefills):
        already_computed = num_computed_tokens_p_cpu[req].item()
        pending = (
            query_start_loc_p_cpu[req + 1].item()
            - query_start_loc_p_cpu[req].item()
        )

        # Earlier computation stopped in the middle of a chunk: spend the
        # first chunk on completing it.
        if already_computed % size != 0:
            chunk_owner.append(req)
            chunk_starts.append(cursor)
            # Tokens missing up to the next chunk boundary ...
            gap = cdiv(already_computed, size) * size - already_computed
            # ... capped by what this request actually brings.
            head = min(gap, pending)
            cursor += head
            pending -= head

        # The rest is cut into full chunks, with a possibly shorter tail.
        for _ in range(cdiv(pending, size)):
            chunk_owner.append(req)
            chunk_starts.append(cursor)
            step = min(size, pending)
            cursor += step
            pending -= step

        assert pending == 0
        final_chunk_of_req.append(len(chunk_starts) - 1)

    # Closing offset so that chunk i spans [starts[i], starts[i + 1]).
    chunk_starts.append(cursor)

    return chunk_starts, chunk_owner, final_chunk_of_req
def
build
(
self
,
common_prefix_len
:
int
,
...
...
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
else
False
)
num_reqs
=
common
.
num_reqs
num_prefills
=
common
.
num_prefills
num_decode_tokens
=
common
.
num_decode_tokens
num_computed_tokens_cpu
=
(
common_attn_metadata
.
compute_num_computed_tokens
().
cpu
()
)
num_computed_tokens_p_cpu
=
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
query_start_loc_p_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
cu_chunk_seqlen
,
seq_idx
,
last_chunk_indices
=
self
.
_compute_chunk_metadata
(
num_prefills
,
num_computed_tokens_p_cpu
,
query_start_loc_p_cpu
,
)
seq_idx_p
=
torch
.
as_tensor
(
seq_idx
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
cu_chunk_seqlen_p
=
torch
.
as_tensor
(
cu_chunk_seqlen
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
last_chunk_indices_p
=
torch
.
as_tensor
(
last_chunk_indices
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
cu_chunk_seqlen_p
,
seq_idx_p
,
last_chunk_indices_p
=
(
self
.
_build_chunk_metadata_tensors
(
self
.
chunk_size
,
common
,
common_attn_metadata
,
)
)
return
replace
(
...
...
vllm/v1/attention/backends/mamba_attn.py
View file @
bbf81f9a
...
...
@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
# The following tensor is only used for prefix caching in align mode
seq_lens
:
torch
.
Tensor
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p
:
torch
.
Tensor
|
None
=
None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p
:
torch
.
Tensor
|
None
=
None
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
...
...
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata
,
num_accepted_tokens
=
num_accepted_tokens
)
def _compute_chunk_metadata(
    self,
    chunk_size: int,
    num_prefills: int,
    num_computed_tokens_p_cpu: torch.Tensor,
    query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
    """
    Compute chunk-specific metadata for Mamba models.

    Each prefill request's new tokens are cut into chunks such that:

    1. Every chunk holds tokens from a *single* sequence only.
    2. Inside a sequence, chunks end on multiples of ``chunk_size``
       measured from the sequence start (already-computed tokens
       included), so the mamba state can be retrieved *every*
       ``chunk_size`` tokens.

    (1) keeps the mamba kernels simple; (2) keeps prefix caching simple.
    To honour (2) under chunked prefill, a request whose computed-token
    count is not chunk-aligned first gets a short chunk that finishes the
    partially filled one.

    Args:
        chunk_size: number of tokens per full chunk.
        num_prefills: number of prefill requests to process.
        num_computed_tokens_p_cpu: indexed per request; tokens already
            computed for that request.
        query_start_loc_p_cpu: indexed per request; consecutive entries
            delimit the request's new tokens (``num_prefills + 1`` entries
            are read).

    Returns:
        ``(cu_chunk_seqlen, seq_idx, last_chunk_indices)`` where
        ``cu_chunk_seqlen`` has one start offset per chunk plus a trailing
        total, ``seq_idx`` maps each chunk to its request index, and
        ``last_chunk_indices`` holds each request's final chunk index.
    """
    # TODO (tdoublep): This code could probably be optimized.
    offsets: list[int] = []
    owners: list[int] = []
    last_per_request: list[int] = []
    position = 0

    for request in range(num_prefills):
        computed = num_computed_tokens_p_cpu[request].item()
        remaining = (
            query_start_loc_p_cpu[request + 1].item()
            - query_start_loc_p_cpu[request].item()
        )

        if computed % chunk_size != 0:
            # Not chunk-aligned: open a leading chunk that tops up the
            # partially filled one.
            owners.append(request)
            offsets.append(position)
            # Distance to the next multiple of chunk_size, limited to the
            # tokens that are actually available.
            to_boundary = cdiv(computed, chunk_size) * chunk_size - computed
            taken = min(to_boundary, remaining)
            position += taken
            remaining -= taken

        full_or_tail_chunks = cdiv(remaining, chunk_size)
        for _ in range(full_or_tail_chunks):
            owners.append(request)
            offsets.append(position)
            taken = min(chunk_size, remaining)
            position += taken
            remaining -= taken

        assert remaining == 0

        # Index of the chunk opened most recently (the sentinel offset is
        # appended only after the loop, so len - 1 is that chunk).
        last_per_request.append(len(offsets) - 1)

    # Sentinel: total number of tokens, closing the last chunk.
    offsets.append(position)

    return offsets, owners, last_per_request
def _build_chunk_metadata_tensors(
    self,
    chunk_size: int,
    common: M,
    common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute chunk metadata for the prefill requests and return it as
    int32 tensors on the device of ``query_start_loc``.

    Returns ``(cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p)``.
    """
    total_reqs = common.num_reqs
    prefill_count = common.num_prefills
    decode_tokens = common.num_decode_tokens

    # Prefill requests occupy the last `prefill_count` slots of the batch.
    computed_cpu = common_attn_metadata.compute_num_computed_tokens().cpu()
    computed_prefill_cpu = computed_cpu[total_reqs - prefill_count : total_reqs]

    # Take the trailing prefill_count + 1 start offsets and shift them down
    # by the number of decode tokens.
    # NOTE(review): presumably this rebases the offsets so that the first
    # prefill token is at 0 - confirm against the batch layout.
    prefill_starts_cpu = (
        common_attn_metadata.query_start_loc_cpu[-prefill_count - 1 :]
        - decode_tokens
    )

    chunk_offsets, chunk_seq_ids, last_chunk_ids = self._compute_chunk_metadata(
        chunk_size,
        prefill_count,
        computed_prefill_cpu,
        prefill_starts_cpu,
    )

    target_device = common_attn_metadata.query_start_loc.device

    def _as_int32(values):
        # Turn one Python list into an int32 tensor on the target device.
        return torch.as_tensor(values, device=target_device, dtype=torch.int32)

    cu_chunk_seqlen_p = _as_int32(chunk_offsets)
    seq_idx_p = _as_int32(chunk_seq_ids)
    last_chunk_indices_p = _as_int32(last_chunk_ids)

    return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
def
_compute_prefix_caching_block_indices
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment