Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fea3e476
"vscode:/vscode.git/clone" did not exist on "cf2f084d56a1293cb08da2393984cdc7685ac019"
Unverified
Commit
fea3e476
authored
Sep 29, 2025
by
Thomas Parnell
Committed by
GitHub
Sep 29, 2025
Browse files
[Kernel] Chunk-aligned mamba2 (#24683)
parent
61a34316
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
250 additions
and
434 deletions
+250
-434
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+4
-4
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
+17
-25
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+79
-157
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
+19
-29
vllm/model_executor/layers/mamba/ops/ssd_combined.py
vllm/model_executor/layers/mamba/ops/ssd_combined.py
+21
-35
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+26
-71
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+4
-4
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+80
-109
No files found.
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
fea3e476
...
@@ -502,9 +502,9 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -502,9 +502,9 @@ class MambaMixer2(MambaBase, CustomOp):
prep_initial_states
=
attn_metadata
.
prep_initial_states
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx_p
seq_idx_p
=
attn_metadata
.
seq_idx_p
chunk_indices_p
=
attn_metadata
.
chunk_indices_p
chunk_offsets_p
=
attn_metadata
.
chunk_offsets_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
cu_chunk_seqlen_p
=
attn_metadata
.
cu_chunk_seqlen_p
last_chunk_indices_p
=
attn_metadata
.
last_chunk_indices_p
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
projected_states
,
_
=
self
.
in_proj
(
hidden_states
)
...
@@ -634,9 +634,9 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -634,9 +634,9 @@ class MambaMixer2(MambaBase, CustomOp):
z
=
None
,
z
=
None
,
dt_bias
=
self
.
dt_bias
,
dt_bias
=
self
.
dt_bias
,
seq_idx
=
seq_idx_p
,
seq_idx
=
seq_idx_p
,
chunk_indices
=
chunk_indices_p
,
chunk_offsets
=
chunk_offsets_p
,
cu_seqlens
=
query_start_loc_p
,
cu_seqlens
=
query_start_loc_p
,
cu_chunk_seqlens
=
cu_chunk_seqlen_p
,
last_chunk_indices
=
last_chunk_indices_p
,
initial_states
=
initial_states
,
initial_states
=
initial_states
,
dt_softplus
=
True
,
dt_softplus
=
True
,
dt_limit
=
(
0.0
,
float
(
"inf"
)),
dt_limit
=
(
0.0
,
float
(
"inf"
)),
...
...
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
View file @
fea3e476
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
# ruff: noqa: E501,SIM102
# ruff: noqa: E501,SIM102
import
math
import
torch
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
...
@@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
...
@@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
a_ptr
,
a_ptr
,
b_ptr
,
b_ptr
,
out_ptr
,
out_ptr
,
seq_idx
_ptr
,
cu_chunk_seqlens
_ptr
,
# Matrix dimensions
# Matrix dimensions
seqlen
,
seqlen
,
chunk_size
:
tl
.
constexpr
,
chunk_size
:
tl
.
constexpr
,
...
@@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
...
@@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
stride_out_head
:
tl
.
int64
,
stride_out_head
:
tl
.
int64
,
stride_outm
:
tl
.
int64
,
stride_outm
:
tl
.
int64
,
stride_outn
:
tl
.
constexpr
,
stride_outn
:
tl
.
constexpr
,
stride_seq_idx_seqlen
:
tl
.
constexpr
,
# Meta-parameters
# Meta-parameters
IS_CAUSAL
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
dot_dtype
:
tl
.
constexpr
,
dot_dtype
:
tl
.
constexpr
,
...
@@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
...
@@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
if
IS_CAUSAL
:
if
IS_CAUSAL
:
if
pid_n
*
BLOCK_SIZE_N
>=
(
pid_m
+
1
)
*
BLOCK_SIZE_M
:
if
pid_n
*
BLOCK_SIZE_N
>=
(
pid_m
+
1
)
*
BLOCK_SIZE_M
:
return
return
a_ptr
+=
pid_c
*
chunk_size
*
stride_a_seqlen
+
pid_h
*
stride_a_head
b_ptr
+=
pid_c
*
chunk_size
*
stride_b_seqlen
+
pid_h
*
stride_b_head
seq_idx_ptr
+=
pid_c
*
chunk_size
*
stride_seq_idx_seqlen
chunk_seqlen_start
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
)
chunk_seqlen_end
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
+
1
)
a_ptr
+=
chunk_seqlen_start
*
stride_a_seqlen
+
pid_h
*
stride_a_head
b_ptr
+=
chunk_seqlen_start
*
stride_b_seqlen
+
pid_h
*
stride_b_head
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
...
@@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
...
@@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
offs_k
[
None
,
:]
*
stride_ak
)
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
b_ptr
+
(
offs_k
[:,
None
]
*
stride_bk
+
b_ptrs
=
b_ptr
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_n
[
None
,
:]
*
stride_b_seqlen
)
offs_n
[
None
,
:]
*
stride_b_seqlen
)
chunk_size_limit
=
min
(
chunk_s
ize
,
seqlen
-
pid_c
*
chunk_size
)
chunk_size_limit
=
chunk_s
eqlen_end
-
chunk_seqlen_start
acc
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
...
@@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
...
@@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
# Zero out the results that are not from the same request
# in the varlen batch
seq_idx_m
=
tl
.
load
(
seq_idx_ptr
+
offs_m
*
stride_seq_idx_seqlen
,
mask
=
offs_m
<
chunk_size_limit
,
other
=-
1
)
seq_idx_n
=
tl
.
load
(
seq_idx_ptr
+
offs_n
*
stride_seq_idx_seqlen
,
mask
=
offs_n
<
chunk_size_limit
,
other
=-
2
)
acc
=
tl
.
where
(
seq_idx_m
[:,
None
]
==
seq_idx_n
[
None
,
:],
acc
,
0.0
)
out
=
acc
.
to
(
out_ptr
.
dtype
.
element_ty
)
out
=
acc
.
to
(
out_ptr
.
dtype
.
element_ty
)
out_ptr
+=
pid_c
*
stride_out_chunk
+
pid_h
*
stride_out_head
out_ptr
+=
pid_c
*
stride_out_chunk
+
pid_h
*
stride_out_head
out_ptrs
=
out_ptr
+
(
stride_outm
*
offs_m
[:,
None
]
+
out_ptrs
=
out_ptr
+
(
stride_outm
*
offs_m
[:,
None
]
+
...
@@ -182,12 +171,18 @@ def _bmm_chunk_fwd_kernel(
...
@@ -182,12 +171,18 @@ def _bmm_chunk_fwd_kernel(
(
offs_n
[
None
,
:]
<
chunk_size
))
(
offs_n
[
None
,
:]
<
chunk_size
))
def
_bmm_chunk_fwd
(
a
,
b
,
chunk_size
,
seq_idx
,
causal
=
False
,
output_dtype
=
None
):
def
_bmm_chunk_fwd
(
a
,
b
,
chunk_size
,
cu_chunk_seqlens
,
causal
=
False
,
output_dtype
=
None
):
"""
"""
Argument:
Argument:
a: (seqlen, ngroups, k)
a: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
chunk_size: int
cu_chunk_seq_lens: (nchunks+1,)
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
guaranteed to be correct.
Return:
Return:
...
@@ -195,14 +190,12 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
...
@@ -195,14 +190,12 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
"""
"""
seqlen
,
ngroups
,
k
=
a
.
shape
seqlen
,
ngroups
,
k
=
a
.
shape
assert
b
.
shape
==
a
.
shape
assert
b
.
shape
==
a
.
shape
assert
seq_idx
is
not
None
assert
seq_idx
.
shape
==
(
seqlen
,
)
if
a
.
stride
(
-
1
)
!=
1
and
a
.
stride
(
0
)
!=
1
:
if
a
.
stride
(
-
1
)
!=
1
and
a
.
stride
(
0
)
!=
1
:
a
=
a
.
contiguous
()
a
=
a
.
contiguous
()
if
b
.
stride
(
-
1
)
!=
1
and
b
.
stride
(
0
)
!=
1
:
if
b
.
stride
(
-
1
)
!=
1
and
b
.
stride
(
0
)
!=
1
:
b
=
b
.
contiguous
()
b
=
b
.
contiguous
()
nchunks
=
math
.
ceil
(
seqlen
/
chunk_size
)
nchunks
=
len
(
cu_chunk_seqlens
)
-
1
# Allocates output.
# Allocates output.
out_dtype
=
a
.
dtype
if
output_dtype
is
None
else
output_dtype
out_dtype
=
a
.
dtype
if
output_dtype
is
None
else
output_dtype
out
=
torch
.
empty
((
nchunks
,
ngroups
,
chunk_size
,
chunk_size
),
out
=
torch
.
empty
((
nchunks
,
ngroups
,
chunk_size
,
chunk_size
),
...
@@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
...
@@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
a_ptr
=
a
,
a_ptr
=
a
,
b_ptr
=
b
,
b_ptr
=
b
,
out_ptr
=
out
,
out_ptr
=
out
,
seq_idx_ptr
=
seq_idx
,
cu_chunk_seqlens_ptr
=
cu_chunk_seqlens
,
seqlen
=
seqlen
,
seqlen
=
seqlen
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
K
=
k
,
K
=
k
,
...
@@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
...
@@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
stride_out_head
=
out
.
stride
(
1
),
stride_out_head
=
out
.
stride
(
1
),
stride_outm
=
out
.
stride
(
-
2
),
stride_outm
=
out
.
stride
(
-
2
),
stride_outn
=
out
.
stride
(
-
1
),
stride_outn
=
out
.
stride
(
-
1
),
stride_seq_idx_seqlen
=
seq_idx
.
stride
(
0
),
IS_CAUSAL
=
causal
,
IS_CAUSAL
=
causal
,
dot_dtype
=
dot_dtype
,
dot_dtype
=
dot_dtype
,
)
)
...
...
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
View file @
fea3e476
...
@@ -120,9 +120,7 @@ def _chunk_scan_fwd_kernel(
...
@@ -120,9 +120,7 @@ def _chunk_scan_fwd_kernel(
states_ptr
,
states_ptr
,
D_ptr
,
D_ptr
,
initstates_ptr
,
initstates_ptr
,
chunk_indices_ptr
,
cu_chunk_seqlens_ptr
,
chunk_offsets_ptr
,
chunk_meta_num
,
# Matrix dimensions
# Matrix dimensions
chunk_size
:
tl
.
constexpr
,
chunk_size
:
tl
.
constexpr
,
hdim
:
tl
.
constexpr
,
hdim
:
tl
.
constexpr
,
...
@@ -149,7 +147,7 @@ def _chunk_scan_fwd_kernel(
...
@@ -149,7 +147,7 @@ def _chunk_scan_fwd_kernel(
stride_dA_cs_chunk
:
tl
.
int64
,
stride_dA_cs_chunk
:
tl
.
int64
,
stride_dA_cs_head
:
tl
.
int64
,
stride_dA_cs_head
:
tl
.
int64
,
stride_dA_cs_csize
:
tl
.
constexpr
,
stride_dA_cs_csize
:
tl
.
constexpr
,
stride_seq_idx_
seqlen
:
tl
.
constexpr
,
stride_seq_idx_
chunk
:
tl
.
constexpr
,
stride_C_seqlen
:
tl
.
int64
,
stride_C_seqlen
:
tl
.
int64
,
stride_C_head
:
tl
.
int64
,
stride_C_head
:
tl
.
int64
,
stride_C_dstate
:
tl
.
constexpr
,
stride_C_dstate
:
tl
.
constexpr
,
...
@@ -175,158 +173,95 @@ def _chunk_scan_fwd_kernel(
...
@@ -175,158 +173,95 @@ def _chunk_scan_fwd_kernel(
HAS_INITSTATES
:
tl
.
constexpr
,
HAS_INITSTATES
:
tl
.
constexpr
,
):
):
pid_c
=
tl
.
program_id
(
axis
=
1
).
to
(
tl
.
int64
)
pid_c
=
tl
.
program_id
(
axis
=
1
).
to
(
tl
.
int64
)
if
not
HAS_INITSTATES
:
c_idx
=
pid_c
c_off
=
0
else
:
c_idx
=
tl
.
load
(
chunk_indices_ptr
+
pid_c
,
mask
=
pid_c
>
-
1
,
other
=
0
)
c_off
=
tl
.
load
(
chunk_offsets_ptr
+
pid_c
,
mask
=
pid_c
>
-
1
,
other
=
0
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
num_pid_n
=
tl
.
cdiv
(
hdim
,
BLOCK_SIZE_N
)
num_pid_n
=
tl
.
cdiv
(
hdim
,
BLOCK_SIZE_N
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
//
num_pid_n
pid_m
=
tl
.
program_id
(
axis
=
0
)
//
num_pid_n
pid_n
=
tl
.
program_id
(
axis
=
0
)
%
num_pid_n
pid_n
=
tl
.
program_id
(
axis
=
0
)
%
num_pid_n
cb_ptr
+=
c_idx
*
stride_cb_chunk
+
(
pid_h
//
cb_ptr
+=
pid_c
*
stride_cb_chunk
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_cb_head
nheads_ngroups_ratio
)
*
stride_cb_head
x_ptr
+=
c_idx
*
chunk_size
*
stride_x_seqlen
+
pid_h
*
stride_x_head
chunk_seqlen_start
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
)
dt_ptr
+=
c_idx
*
stride_dt_chunk
+
pid_h
*
stride_dt_head
chunk_seqlen_end
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
+
1
)
dA_cumsum_ptr
+=
c_idx
*
stride_dA_cs_chunk
+
pid_h
*
stride_dA_cs_head
x_ptr
+=
chunk_seqlen_start
*
stride_x_seqlen
+
pid_h
*
stride_x_head
C_ptr
+=
c_idx
*
chunk_size
*
stride_C_seqlen
+
(
dt_ptr
+=
pid_c
*
stride_dt_chunk
+
pid_h
*
stride_dt_head
dA_cumsum_ptr
+=
pid_c
*
stride_dA_cs_chunk
+
pid_h
*
stride_dA_cs_head
C_ptr
+=
chunk_seqlen_start
*
stride_C_seqlen
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_C_head
pid_h
//
nheads_ngroups_ratio
)
*
stride_C_head
# M-block offsets and prev states
# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
# - logic in next block may override these if there is an active offset
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
c_off
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
prev_states_ptr
=
states_ptr
+
c_idx
*
stride_states_chunk
+
pid_h
*
stride_states_head
prev_states_hdim
=
stride_states_hdim
prev_states_dstate
=
stride_states_dstate
chunk_size_limit
=
min
(
chunk_size
,
seqlen
-
c_idx
*
chunk_size
)
seq_idx_ptr
+=
c_idx
*
chunk_size
*
stride_seq_idx_seqlen
seq_idx_ptr
+=
pid_c
*
stride_seq_idx_chunk
# - we only need seq_idx_prev to be aligned to chunk boundary
seq_idx
=
tl
.
load
(
seq_idx_ptr
)
seq_idx_prev
=
tl
.
load
(
seq_idx_ptr
-
stride_seq_idx_seqlen
,
seq_idx_prev
=
tl
.
load
(
seq_idx_ptr
-
stride_seq_idx_chunk
,
mask
=
c_idx
>=
1
,
mask
=
pid_c
>=
1
,
other
=
0
)
other
=-
1
)
if
HAS_INITSTATES
:
# if there are init states, we only need seq_idx_m to point
# what is the current seq_idx
# get current seq idx
if
(
pid_m
*
BLOCK_SIZE_M
+
c_off
)
<
chunk_size_limit
:
seq_idx_m
=
tl
.
load
(
seq_idx_ptr
+
(
pid_m
*
BLOCK_SIZE_M
+
c_off
)
*
stride_seq_idx_seqlen
,
)
# - recall that in ssd_state_passing, for the case c_off == 0
# i.e., the very first sequence, we made states_ptr hold its initial state
# so this edge case is taken care of
if
((
c_off
==
0
)
and
(
seq_idx_prev
!=
seq_idx_m
)
# if a seq is changed exactly on boundary
or
(
c_off
>
0
)
# implies a new example (pseudo chunk)
):
# - replace prev_states_ptr with init_states
if
HAS_INITSTATES
and
(
seq_idx
!=
seq_idx_prev
):
prev_states_ptr
=
initstates_ptr
+
seq_idx
_m
*
stride_init_states_batch
+
pid_h
*
stride_init_states_head
prev_states_ptr
=
initstates_ptr
+
seq_idx
*
stride_init_states_batch
+
pid_h
*
stride_init_states_head
prev_states_hdim
=
stride_init_states_hdim
# override strides
prev_states_hdim
=
stride_init_states_hdim
prev_states_dstate
=
stride_init_states_dstate
prev_states_dstate
=
stride_init_states_dstate
else
:
prev_states_ptr
=
states_ptr
+
(
pid_c
-
1
)
*
stride_states_chunk
+
pid_h
*
stride_states_head
prev_states_hdim
=
stride_states_hdim
prev_states_dstate
=
stride_states_dstate
chunk_size_limit
=
chunk_seqlen_end
-
chunk_seqlen_start
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
dA_cs_m
=
tl
.
load
(
dA_cumsum_ptr
+
offs_m
*
stride_dA_cs_csize
,
dA_cs_m
=
tl
.
load
(
dA_cumsum_ptr
+
offs_m
*
stride_dA_cs_csize
,
mask
=
offs_m
<
chunk_size
,
mask
=
offs_m
<
chunk_size
,
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
# - handle chunk state limit
if
HAS_INITSTATES
:
# have to split this if otherwise compilation will have problems
dA_cs_m_boundary
=
0.0
# get the c_idx for the next (logica) chunk
c_idx_n
=
tl
.
load
(
chunk_indices_ptr
+
(
pid_c
+
1
),
mask
=
pid_c
>
-
1
and
(
pid_c
+
1
)
<
chunk_meta_num
,
other
=-
1
# to trigger different chunk
)
# - there are things to consider
# A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
# contribution of past states
# B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
# encroach into the next sequence, where c_off_n is the offset of the next
# (logical) chunk.
# An equivalent check for B is c_idx == c_idx_n, where there is repetition in
# (logical) chunk indices.
if
(
c_idx
==
c_idx_n
)
or
c_off
>
0
:
# get the next offset
c_off_n
=
tl
.
load
(
chunk_offsets_ptr
+
(
pid_c
+
1
),
mask
=
pid_c
>
-
1
and
(
pid_c
+
1
)
<
chunk_meta_num
,
other
=
chunk_size
)
# in this case, adjust down the chunk_size_limit
if
c_idx
==
c_idx_n
:
chunk_size_limit
=
min
(
c_off_n
,
chunk_size_limit
)
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
# - We need dA_cs at the boundary, defined by c_off - no need
# to increase pointer by pid_m (it is a constant offset,
# i.e. the same for all blocks)
dA_cs_m_boundary
=
tl
.
load
(
dA_cumsum_ptr
+
(
c_off
-
1
)
*
stride_dA_cs_csize
,
mask
=
(((
c_off
-
1
)
>
-
1
)
and
((
c_off
)
<
chunk_size
)),
other
=
0.0
).
to
(
tl
.
float32
)
else
:
# - handle seq idx when HAS_INITSTATES==False
seq_idx_m
=
tl
.
load
(
seq_idx_ptr
+
offs_m
*
stride_seq_idx_seqlen
,
mask
=
offs_m
<
chunk_size_limit
,
other
=-
1
)
acc
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
# Without the if (pid_c > -1), with Triton 2.1.0, I get
offs_out_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
offs_out_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
# With Triton 2.2.0, this works
if
IS_TRITON_22
or
c_idx
>
-
1
:
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
offs_k_dstate
=
tl
.
arange
(
offs_k_dstate
=
tl
.
arange
(
0
,
BLOCK_SIZE_DSTATE
if
BLOCK_SIZE_DSTATE
<=
128
else
BLOCK_SIZE_K
)
0
,
BLOCK_SIZE_DSTATE
if
BLOCK_SIZE_DSTATE
<=
128
else
BLOCK_SIZE_K
)
C_ptrs
=
C_ptr
+
(
offs_m
[:,
None
]
*
stride_C_seqlen
+
C_ptrs
=
C_ptr
+
(
offs_m
[:,
None
]
*
stride_C_seqlen
+
offs_k_dstate
[
None
,
:]
*
stride_C_dstate
)
offs_k_dstate
[
None
,
:]
*
stride_C_dstate
)
prev_states_ptrs
=
prev_states_ptr
+
(
scale_m
=
tl
.
exp
(
dA_cs_m
)
offs_n
[
None
,
:]
*
prev_states_hdim
+
offs_k_dstate
[:,
None
]
*
prev_states_dstate
)
if
not
HAS_INITSTATES
:
# - this is for continuous batching where there is no init states
scale_m
=
tl
.
where
(
seq_idx_m
==
seq_idx_prev
,
tl
.
exp
(
dA_cs_m
),
0.0
)
else
:
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m
=
tl
.
exp
(
dA_cs_m
-
dA_cs_m_boundary
)
if
BLOCK_SIZE_DSTATE
<=
128
:
if
BLOCK_SIZE_DSTATE
<=
128
:
C
=
tl
.
load
(
C_ptrs
,
C
=
tl
.
load
(
C_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
chunk_size_limit
)
&
mask
=
(
offs_m
[:,
None
]
<
chunk_size_limit
)
&
(
offs_k_dstate
[
None
,
:]
<
dstate
),
(
offs_k_dstate
[
None
,
:]
<
dstate
),
other
=
0.0
)
other
=
0.0
)
if
not
HAS_INITSTATES
and
(
seq_idx
!=
seq_idx_prev
):
# if no init states AND starting a new sequence, we need zeros
prev_states
=
tl
.
zeros
((
BLOCK_SIZE_DSTATE
,
BLOCK_SIZE_N
),
dtype
=
C_ptr
.
dtype
.
element_ty
)
else
:
# otherwise read the previous state
prev_states_ptrs
=
prev_states_ptr
\
+
offs_n
[
None
,
:]
*
prev_states_hdim
\
+
offs_k_dstate
[:,
None
]
*
prev_states_dstate
prev_states
=
tl
.
load
(
prev_states_ptrs
,
prev_states
=
tl
.
load
(
prev_states_ptrs
,
mask
=
(
offs_k_dstate
[:,
None
]
<
dstate
)
&
mask
=
(
offs_k_dstate
[:,
None
]
<
dstate
)
&
(
offs_n
[
None
,
:]
<
hdim
),
(
offs_n
[
None
,
:]
<
hdim
),
other
=
0.0
)
other
=
0.0
)
prev_states
=
prev_states
.
to
(
C_ptr
.
dtype
.
element_ty
)
prev_states
=
prev_states
.
to
(
C_ptr
.
dtype
.
element_ty
)
acc
=
tl
.
dot
(
C
,
prev_states
)
*
scale_m
[:,
None
]
acc
=
tl
.
dot
(
C
,
prev_states
)
*
scale_m
[:,
None
]
else
:
else
:
prev_states_ptrs
=
prev_states_ptr
\
+
offs_n
[
None
,
:]
*
prev_states_hdim
\
+
offs_k_dstate
[:,
None
]
*
prev_states_dstate
for
k
in
range
(
0
,
dstate
,
BLOCK_SIZE_K
):
for
k
in
range
(
0
,
dstate
,
BLOCK_SIZE_K
):
C
=
tl
.
load
(
C_ptrs
,
C
=
tl
.
load
(
C_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
chunk_size_limit
)
&
mask
=
(
offs_m
[:,
None
]
<
chunk_size_limit
)
&
(
offs_k_dstate
[
None
,
:]
<
dstate
-
k
),
(
offs_k_dstate
[
None
,
:]
<
dstate
-
k
),
other
=
0.0
)
other
=
0.0
)
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
if
not
HAS_INITSTATES
and
(
seq_idx
!=
seq_idx_prev
):
prev_states
=
tl
.
zeros
((
BLOCK_SIZE_DSTATE
,
BLOCK_SIZE_K
),
dtype
=
C_ptr
.
dtype
.
element_ty
)
else
:
prev_states
=
tl
.
load
(
prev_states
=
tl
.
load
(
prev_states_ptrs
,
prev_states_ptrs
,
mask
=
(
offs_k_dstate
[:,
None
]
<
dstate
-
k
)
&
mask
=
(
offs_k_dstate
[:,
None
]
<
dstate
-
k
)
&
...
@@ -338,7 +273,7 @@ def _chunk_scan_fwd_kernel(
...
@@ -338,7 +273,7 @@ def _chunk_scan_fwd_kernel(
prev_states_ptrs
+=
BLOCK_SIZE_K
prev_states_ptrs
+=
BLOCK_SIZE_K
acc
*=
scale_m
[:,
None
]
acc
*=
scale_m
[:,
None
]
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
+
c_off
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
cb_ptrs
=
cb_ptr
+
(
offs_m
[:,
None
]
*
stride_cb_csize_m
+
cb_ptrs
=
cb_ptr
+
(
offs_m
[:,
None
]
*
stride_cb_csize_m
+
offs_k
[
None
,
:]
*
stride_cb_csize_k
)
offs_k
[
None
,
:]
*
stride_cb_csize_k
)
x_ptrs
=
x_ptr
+
(
offs_k
[:,
None
]
*
stride_x_seqlen
+
x_ptrs
=
x_ptr
+
(
offs_k
[:,
None
]
*
stride_x_seqlen
+
...
@@ -375,7 +310,7 @@ def _chunk_scan_fwd_kernel(
...
@@ -375,7 +310,7 @@ def _chunk_scan_fwd_kernel(
dt_ptrs
+=
BLOCK_SIZE_K
*
stride_dt_csize
dt_ptrs
+=
BLOCK_SIZE_K
*
stride_dt_csize
dA_cumsum_ptrs
+=
BLOCK_SIZE_K
*
stride_dA_cs_csize
dA_cumsum_ptrs
+=
BLOCK_SIZE_K
*
stride_dA_cs_csize
offs_out_m
=
pid_m
*
BLOCK_SIZE_M
+
c_off
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_out_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_out_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_out_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
if
HAS_D
:
if
HAS_D
:
...
@@ -393,7 +328,7 @@ def _chunk_scan_fwd_kernel(
...
@@ -393,7 +328,7 @@ def _chunk_scan_fwd_kernel(
acc
+=
x_residual
*
D
acc
+=
x_residual
*
D
if
HAS_Z
:
if
HAS_Z
:
z_ptr
+=
c
_idx
*
chunk_size
*
stride_z_seqlen
+
pid_h
*
stride_z_head
z_ptr
+=
c
hunk_seqlen_start
*
stride_z_seqlen
+
pid_h
*
stride_z_head
z_ptrs
=
z_ptr
+
(
stride_z_seqlen
*
offs_out_m
[:,
None
]
+
z_ptrs
=
z_ptr
+
(
stride_z_seqlen
*
offs_out_m
[:,
None
]
+
stride_z_hdim
*
offs_out_n
[
None
,
:])
stride_z_hdim
*
offs_out_n
[
None
,
:])
z
=
tl
.
load
(
z_ptrs
,
z
=
tl
.
load
(
z_ptrs
,
...
@@ -402,7 +337,7 @@ def _chunk_scan_fwd_kernel(
...
@@ -402,7 +337,7 @@ def _chunk_scan_fwd_kernel(
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
acc
*=
z
*
tl
.
sigmoid
(
z
)
acc
*=
z
*
tl
.
sigmoid
(
z
)
out_ptr
+=
c
_idx
*
chunk_size
*
stride_out_seqlen
+
pid_h
*
stride_out_head
out_ptr
+=
c
hunk_seqlen_start
*
stride_out_seqlen
+
pid_h
*
stride_out_head
out_ptrs
=
out_ptr
+
(
stride_out_seqlen
*
offs_out_m
[:,
None
]
+
out_ptrs
=
out_ptr
+
(
stride_out_seqlen
*
offs_out_m
[:,
None
]
+
offs_out_n
[
None
,
:]
*
stride_out_hdim
)
offs_out_n
[
None
,
:]
*
stride_out_hdim
)
tl
.
store
(
out_ptrs
,
tl
.
store
(
out_ptrs
,
...
@@ -418,12 +353,11 @@ def _chunk_scan_fwd(
...
@@ -418,12 +353,11 @@ def _chunk_scan_fwd(
dA_cumsum
,
dA_cumsum
,
C
,
C
,
states
,
states
,
cu_chunk_seqlens
,
out
,
out
,
seq_idx
,
seq_idx
,
D
=
None
,
D
=
None
,
z
=
None
,
z
=
None
,
chunk_indices
=
None
,
chunk_offsets
=
None
,
initial_states
=
None
,
initial_states
=
None
,
):
):
assert
seq_idx
is
not
None
,
"this implementation requires seq_idx"
assert
seq_idx
is
not
None
,
"this implementation requires seq_idx"
...
@@ -441,20 +375,10 @@ def _chunk_scan_fwd(
...
@@ -441,20 +375,10 @@ def _chunk_scan_fwd(
assert
dt
.
shape
==
(
nheads
,
nchunks
,
chunk_size
)
assert
dt
.
shape
==
(
nheads
,
nchunks
,
chunk_size
)
assert
dA_cumsum
.
shape
==
(
nheads
,
nchunks
,
chunk_size
)
assert
dA_cumsum
.
shape
==
(
nheads
,
nchunks
,
chunk_size
)
assert
states
.
shape
==
(
nchunks
,
nheads
,
headdim
,
dstate
)
assert
states
.
shape
==
(
nchunks
,
nheads
,
headdim
,
dstate
)
assert
seq_idx
.
shape
==
(
seqlen
,
)
assert
seq_idx
.
shape
==
(
nchunks
,
)
if
initial_states
is
not
None
:
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert
chunk_indices
is
not
None
and
chunk_offsets
is
not
None
,
\
"chunk_indices and chunk_offsets should have been set"
else
:
chunk_indices
,
chunk_offsets
=
None
,
None
grid
=
lambda
META
:
(
grid
=
lambda
META
:
(
triton
.
cdiv
(
chunk_size
,
META
[
'BLOCK_SIZE_M'
])
*
triton
triton
.
cdiv
(
chunk_size
,
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
.
cdiv
(
headdim
,
META
[
'BLOCK_SIZE_N'
]),
nchunks
,
nheads
)
headdim
,
META
[
'BLOCK_SIZE_N'
]),
nchunks
if
chunk_offsets
is
None
else
len
(
chunk_offsets
),
nheads
)
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
(
0
,
0
,
0
))
(
0
,
0
,
0
))
...
@@ -476,9 +400,7 @@ def _chunk_scan_fwd(
...
@@ -476,9 +400,7 @@ def _chunk_scan_fwd(
states_ptr
=
states
,
states_ptr
=
states
,
D_ptr
=
D
,
D_ptr
=
D
,
initstates_ptr
=
initial_states
,
initstates_ptr
=
initial_states
,
chunk_indices_ptr
=
chunk_indices
,
cu_chunk_seqlens_ptr
=
cu_chunk_seqlens
,
chunk_offsets_ptr
=
chunk_offsets
,
chunk_meta_num
=
len
(
chunk_indices
)
if
chunk_indices
is
not
None
else
0
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
hdim
=
headdim
,
hdim
=
headdim
,
dstate
=
dstate
,
dstate
=
dstate
,
...
@@ -503,7 +425,7 @@ def _chunk_scan_fwd(
...
@@ -503,7 +425,7 @@ def _chunk_scan_fwd(
stride_dA_cs_chunk
=
dA_cumsum
.
stride
(
1
),
stride_dA_cs_chunk
=
dA_cumsum
.
stride
(
1
),
stride_dA_cs_head
=
dA_cumsum
.
stride
(
0
),
stride_dA_cs_head
=
dA_cumsum
.
stride
(
0
),
stride_dA_cs_csize
=
dA_cumsum
.
stride
(
2
),
stride_dA_cs_csize
=
dA_cumsum
.
stride
(
2
),
stride_seq_idx_
seqlen
=
seq_idx
.
stride
(
0
),
stride_seq_idx_
chunk
=
seq_idx
.
stride
(
0
),
stride_C_seqlen
=
C
.
stride
(
0
),
stride_C_seqlen
=
C
.
stride
(
0
),
stride_C_head
=
C
.
stride
(
1
),
stride_C_head
=
C
.
stride
(
1
),
stride_C_dstate
=
C
.
stride
(
2
),
stride_C_dstate
=
C
.
stride
(
2
),
...
...
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
View file @
fea3e476
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
# ruff: noqa: E501
# ruff: noqa: E501
import
math
import
torch
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
...
@@ -34,6 +32,7 @@ def _chunk_cumsum_fwd_kernel(
...
@@ -34,6 +32,7 @@ def _chunk_cumsum_fwd_kernel(
dt_bias_ptr
,
dt_bias_ptr
,
dt_out_ptr
,
dt_out_ptr
,
dA_cumsum_ptr
,
dA_cumsum_ptr
,
cu_chunk_seqlens_ptr
,
# Matrix dimension
# Matrix dimension
seqlen
,
seqlen
,
nheads
:
tl
.
constexpr
,
nheads
:
tl
.
constexpr
,
...
@@ -61,7 +60,11 @@ def _chunk_cumsum_fwd_kernel(
...
@@ -61,7 +60,11 @@ def _chunk_cumsum_fwd_kernel(
# https://github.com/triton-lang/triton/issues/1058
# https://github.com/triton-lang/triton/issues/1058
pid_c
=
tl
.
program_id
(
axis
=
0
).
to
(
tl
.
int64
)
pid_c
=
tl
.
program_id
(
axis
=
0
).
to
(
tl
.
int64
)
pid_h
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
1
)
dt_ptr
+=
pid_c
*
chunk_size
*
stride_dt_seqlen
chunk_seqlen_start
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
)
chunk_seqlen_end
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
+
1
)
dt_ptr
+=
chunk_seqlen_start
*
stride_dt_seqlen
dt_out_ptr
+=
pid_c
*
stride_dt_out_chunk
dt_out_ptr
+=
pid_c
*
stride_dt_out_chunk
dA_cumsum_ptr
+=
pid_c
*
stride_dA_cs_chunk
dA_cumsum_ptr
+=
pid_c
*
stride_dA_cs_chunk
...
@@ -74,7 +77,7 @@ def _chunk_cumsum_fwd_kernel(
...
@@ -74,7 +77,7 @@ def _chunk_cumsum_fwd_kernel(
offs_c
[
None
,
:]
*
stride_dt_out_csize
)
offs_c
[
None
,
:]
*
stride_dt_out_csize
)
dA_cs_ptrs
=
dA_cumsum_ptr
+
(
offs_h
[:,
None
]
*
stride_dA_cs_head
+
dA_cs_ptrs
=
dA_cumsum_ptr
+
(
offs_h
[:,
None
]
*
stride_dA_cs_head
+
offs_c
[
None
,
:]
*
stride_dA_cs_csize
)
offs_c
[
None
,
:]
*
stride_dA_cs_csize
)
chunk_size_limit
=
min
(
chunk_s
ize
,
seqlen
-
pid_c
*
chunk_size
)
chunk_size_limit
=
chunk_s
eqlen_end
-
chunk_seqlen_start
dt
=
tl
.
load
(
dt_ptrs
,
dt
=
tl
.
load
(
dt_ptrs
,
mask
=
(
offs_h
[:,
None
]
<
nheads
)
&
mask
=
(
offs_h
[:,
None
]
<
nheads
)
&
...
@@ -188,7 +191,7 @@ def _chunk_state_fwd_kernel(
...
@@ -188,7 +191,7 @@ def _chunk_state_fwd_kernel(
states_ptr
,
states_ptr
,
dt_ptr
,
dt_ptr
,
dA_cumsum_ptr
,
dA_cumsum_ptr
,
seq_idx
_ptr
,
cu_chunk_seqlens
_ptr
,
# Matrix dimensions
# Matrix dimensions
hdim
:
tl
.
constexpr
,
hdim
:
tl
.
constexpr
,
dstate
:
tl
.
constexpr
,
dstate
:
tl
.
constexpr
,
...
@@ -212,7 +215,6 @@ def _chunk_state_fwd_kernel(
...
@@ -212,7 +215,6 @@ def _chunk_state_fwd_kernel(
stride_dA_cs_head
:
tl
.
int64
,
stride_dA_cs_head
:
tl
.
int64
,
stride_dA_cs_chunk
:
tl
.
int64
,
stride_dA_cs_chunk
:
tl
.
int64
,
stride_dA_cs_csize
:
tl
.
constexpr
,
stride_dA_cs_csize
:
tl
.
constexpr
,
stride_seq_idx_seqlen
:
tl
.
constexpr
,
# Meta-parameters
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
...
@@ -223,14 +225,14 @@ def _chunk_state_fwd_kernel(
...
@@ -223,14 +225,14 @@ def _chunk_state_fwd_kernel(
num_pid_n
=
tl
.
cdiv
(
dstate
,
BLOCK_SIZE_N
)
num_pid_n
=
tl
.
cdiv
(
dstate
,
BLOCK_SIZE_N
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
//
num_pid_n
pid_m
=
tl
.
program_id
(
axis
=
0
)
//
num_pid_n
pid_n
=
tl
.
program_id
(
axis
=
0
)
%
num_pid_n
pid_n
=
tl
.
program_id
(
axis
=
0
)
%
num_pid_n
b_ptr
+=
pid_c
*
chunk_size
*
stride_b_seqlen
+
(
chunk_seqlen_start
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
)
chunk_seqlen_end
=
tl
.
load
(
cu_chunk_seqlens_ptr
+
pid_c
+
1
)
b_ptr
+=
chunk_seqlen_start
*
stride_b_seqlen
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_b_head
pid_h
//
nheads_ngroups_ratio
)
*
stride_b_head
x_ptr
+=
pid_c
*
chunk_size
*
stride_x_seqlen
+
pid_h
*
stride_x_head
x_ptr
+=
chunk_seqlen_start
*
stride_x_seqlen
+
pid_h
*
stride_x_head
dt_ptr
+=
pid_c
*
stride_dt_chunk
+
pid_h
*
stride_dt_head
dt_ptr
+=
pid_c
*
stride_dt_chunk
+
pid_h
*
stride_dt_head
dA_cumsum_ptr
+=
pid_c
*
stride_dA_cs_chunk
+
pid_h
*
stride_dA_cs_head
dA_cumsum_ptr
+=
pid_c
*
stride_dA_cs_chunk
+
pid_h
*
stride_dA_cs_head
seq_idx_ptr
+=
pid_c
*
chunk_size
*
stride_seq_idx_seqlen
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_n
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
...
@@ -243,10 +245,7 @@ def _chunk_state_fwd_kernel(
...
@@ -243,10 +245,7 @@ def _chunk_state_fwd_kernel(
(
chunk_size
-
1
)
*
stride_dA_cs_csize
).
to
(
tl
.
float32
)
(
chunk_size
-
1
)
*
stride_dA_cs_csize
).
to
(
tl
.
float32
)
dA_cumsum_ptrs
=
dA_cumsum_ptr
+
offs_k
*
stride_dA_cs_csize
dA_cumsum_ptrs
=
dA_cumsum_ptr
+
offs_k
*
stride_dA_cs_csize
seq_idx_ptrs
=
seq_idx_ptr
+
offs_k
*
stride_seq_idx_seqlen
chunk_size_limit
=
chunk_seqlen_end
-
chunk_seqlen_start
chunk_size_limit
=
min
(
chunk_size
,
seqlen
-
pid_c
*
chunk_size
)
seq_idx_last
=
tl
.
load
(
seq_idx_ptr
+
(
chunk_size_limit
-
1
)
*
stride_seq_idx_seqlen
)
acc
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
chunk_size_limit
,
BLOCK_SIZE_K
):
for
k
in
range
(
0
,
chunk_size_limit
,
BLOCK_SIZE_K
):
...
@@ -261,15 +260,9 @@ def _chunk_state_fwd_kernel(
...
@@ -261,15 +260,9 @@ def _chunk_state_fwd_kernel(
dA_cs_k
=
tl
.
load
(
dA_cumsum_ptrs
,
dA_cs_k
=
tl
.
load
(
dA_cumsum_ptrs
,
mask
=
offs_k
<
chunk_size_limit
-
k
,
mask
=
offs_k
<
chunk_size_limit
-
k
,
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
seq_idx_k
=
tl
.
load
(
seq_idx_ptrs
,
mask
=
offs_k
<
chunk_size_limit
-
k
,
other
=-
1
)
dt_k
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_k
<
chunk_size_limit
-
k
,
dt_k
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_k
<
chunk_size_limit
-
k
,
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
scale
=
tl
.
exp
(
dA_cs_last
-
dA_cs_k
)
*
dt_k
scale
=
tl
.
where
(
seq_idx_k
==
seq_idx_last
,
tl
.
exp
(
dA_cs_last
-
dA_cs_k
)
*
dt_k
,
0.0
)
b
*=
scale
[:,
None
]
b
*=
scale
[:,
None
]
b
=
b
.
to
(
x_ptr
.
dtype
.
element_ty
)
b
=
b
.
to
(
x_ptr
.
dtype
.
element_ty
)
acc
+=
tl
.
dot
(
x
,
b
)
acc
+=
tl
.
dot
(
x
,
b
)
...
@@ -278,7 +271,6 @@ def _chunk_state_fwd_kernel(
...
@@ -278,7 +271,6 @@ def _chunk_state_fwd_kernel(
b_ptrs
+=
BLOCK_SIZE_K
*
stride_b_seqlen
b_ptrs
+=
BLOCK_SIZE_K
*
stride_b_seqlen
dt_ptrs
+=
BLOCK_SIZE_K
*
stride_dt_csize
dt_ptrs
+=
BLOCK_SIZE_K
*
stride_dt_csize
dA_cumsum_ptrs
+=
BLOCK_SIZE_K
*
stride_dA_cs_csize
dA_cumsum_ptrs
+=
BLOCK_SIZE_K
*
stride_dA_cs_csize
seq_idx_ptrs
+=
BLOCK_SIZE_K
*
stride_seq_idx_seqlen
states
=
acc
.
to
(
states_ptr
.
dtype
.
element_ty
)
states
=
acc
.
to
(
states_ptr
.
dtype
.
element_ty
)
...
@@ -534,6 +526,7 @@ def _chunk_state_varlen_kernel(
...
@@ -534,6 +526,7 @@ def _chunk_state_varlen_kernel(
def
_chunk_cumsum_fwd
(
dt
,
def
_chunk_cumsum_fwd
(
dt
,
A
,
A
,
chunk_size
,
chunk_size
,
cu_chunk_seqlens
,
dt_bias
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
,
dt_softplus
=
False
,
dt_limit
=
(
0.0
,
float
(
"inf"
))):
dt_limit
=
(
0.0
,
float
(
"inf"
))):
...
@@ -541,7 +534,7 @@ def _chunk_cumsum_fwd(dt,
...
@@ -541,7 +534,7 @@ def _chunk_cumsum_fwd(dt,
assert
A
.
shape
==
(
nheads
,
)
assert
A
.
shape
==
(
nheads
,
)
if
dt_bias
is
not
None
:
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
)
assert
dt_bias
.
shape
==
(
nheads
,
)
nchunks
=
math
.
ceil
(
seqlen
/
chunk_size
)
nchunks
=
cu_chunk_seqlens
.
shape
[
0
]
-
1
dt_out
=
torch
.
empty
(
nheads
,
dt_out
=
torch
.
empty
(
nheads
,
nchunks
,
nchunks
,
chunk_size
,
chunk_size
,
...
@@ -561,6 +554,7 @@ def _chunk_cumsum_fwd(dt,
...
@@ -561,6 +554,7 @@ def _chunk_cumsum_fwd(dt,
dt_bias_ptr
=
dt_bias
,
dt_bias_ptr
=
dt_bias
,
dt_out_ptr
=
dt_out
,
dt_out_ptr
=
dt_out
,
dA_cumsum_ptr
=
dA_cumsum
,
dA_cumsum_ptr
=
dA_cumsum
,
cu_chunk_seqlens_ptr
=
cu_chunk_seqlens
,
seqlen
=
seqlen
,
seqlen
=
seqlen
,
nheads
=
nheads
,
nheads
=
nheads
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
...
@@ -588,7 +582,7 @@ def _chunk_state_fwd(B,
...
@@ -588,7 +582,7 @@ def _chunk_state_fwd(B,
x
,
x
,
dt
,
dt
,
dA_cumsum
,
dA_cumsum
,
seq_idx
=
None
,
cu_chunk_seqlens
,
states
=
None
,
states
=
None
,
states_in_fp32
=
True
):
states_in_fp32
=
True
):
seqlen
,
nheads
,
headdim
=
x
.
shape
seqlen
,
nheads
,
headdim
=
x
.
shape
...
@@ -599,9 +593,6 @@ def _chunk_state_fwd(B,
...
@@ -599,9 +593,6 @@ def _chunk_state_fwd(B,
assert
dt
.
shape
==
(
nheads
,
nchunks
,
chunk_size
)
assert
dt
.
shape
==
(
nheads
,
nchunks
,
chunk_size
)
assert
dA_cumsum
.
shape
==
dt
.
shape
assert
dA_cumsum
.
shape
==
dt
.
shape
assert
seq_idx
is
not
None
assert
seq_idx
.
shape
==
(
seqlen
,
)
if
states
is
not
None
:
if
states
is
not
None
:
assert
states
.
shape
==
(
nchunks
,
nheads
,
headdim
,
dstate
)
assert
states
.
shape
==
(
nchunks
,
nheads
,
headdim
,
dstate
)
else
:
else
:
...
@@ -619,7 +610,7 @@ def _chunk_state_fwd(B,
...
@@ -619,7 +610,7 @@ def _chunk_state_fwd(B,
states_ptr
=
states
,
states_ptr
=
states
,
dt_ptr
=
dt
,
dt_ptr
=
dt
,
dA_cumsum_ptr
=
dA_cumsum
,
dA_cumsum_ptr
=
dA_cumsum
,
seq_idx_ptr
=
seq_idx
,
cu_chunk_seqlens_ptr
=
cu_chunk_seqlens
,
hdim
=
headdim
,
hdim
=
headdim
,
dstate
=
dstate
,
dstate
=
dstate
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
...
@@ -641,7 +632,6 @@ def _chunk_state_fwd(B,
...
@@ -641,7 +632,6 @@ def _chunk_state_fwd(B,
stride_dA_cs_head
=
dA_cumsum
.
stride
(
0
),
stride_dA_cs_head
=
dA_cumsum
.
stride
(
0
),
stride_dA_cs_chunk
=
dA_cumsum
.
stride
(
1
),
stride_dA_cs_chunk
=
dA_cumsum
.
stride
(
1
),
stride_dA_cs_csize
=
dA_cumsum
.
stride
(
2
),
stride_dA_cs_csize
=
dA_cumsum
.
stride
(
2
),
stride_seq_idx_seqlen
=
seq_idx
.
stride
(
0
),
)
)
return
states
return
states
...
...
vllm/model_executor/layers/mamba/ops/ssd_combined.py
View file @
fea3e476
...
@@ -14,8 +14,7 @@ from vllm.triton_utils import triton
...
@@ -14,8 +14,7 @@ from vllm.triton_utils import triton
from
.ssd_bmm
import
_bmm_chunk_fwd
from
.ssd_bmm
import
_bmm_chunk_fwd
from
.ssd_chunk_scan
import
_chunk_scan_fwd
from
.ssd_chunk_scan
import
_chunk_scan_fwd
from
.ssd_chunk_state
import
(
_chunk_cumsum_fwd
,
_chunk_state_fwd
,
from
.ssd_chunk_state
import
_chunk_cumsum_fwd
,
_chunk_state_fwd
chunk_state_varlen
)
from
.ssd_state_passing
import
_state_passing_fwd
from
.ssd_state_passing
import
_state_passing_fwd
TRITON_22
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
'2.2.0'
)
TRITON_22
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
'2.2.0'
)
...
@@ -37,9 +36,9 @@ def _mamba_chunk_scan_combined_fwd(x,
...
@@ -37,9 +36,9 @@ def _mamba_chunk_scan_combined_fwd(x,
dt_bias
=
None
,
dt_bias
=
None
,
initial_states
=
None
,
initial_states
=
None
,
seq_idx
=
None
,
seq_idx
=
None
,
chunk_indices
=
None
,
chunk_offsets
=
None
,
cu_seqlens
=
None
,
cu_seqlens
=
None
,
cu_chunk_seqlens
=
None
,
last_chunk_indices
=
None
,
dt_softplus
=
False
,
dt_softplus
=
False
,
dt_limit
=
(
0.0
,
float
(
"inf"
)),
dt_limit
=
(
0.0
,
float
(
"inf"
)),
state_dtype
=
None
):
state_dtype
=
None
):
...
@@ -56,7 +55,7 @@ def _mamba_chunk_scan_combined_fwd(x,
...
@@ -56,7 +55,7 @@ def _mamba_chunk_scan_combined_fwd(x,
if
D
is
not
None
:
if
D
is
not
None
:
assert
D
.
shape
==
(
nheads
,
headdim
)
or
D
.
shape
==
(
nheads
,
)
assert
D
.
shape
==
(
nheads
,
headdim
)
or
D
.
shape
==
(
nheads
,
)
if
seq_idx
is
not
None
:
if
seq_idx
is
not
None
:
assert
seq_idx
.
shape
==
(
seqlen
,
)
assert
seq_idx
.
shape
==
(
cu_chunk_seqlens
.
shape
[
0
]
-
1
,
)
if
B
.
stride
(
-
1
)
!=
1
:
if
B
.
stride
(
-
1
)
!=
1
:
B
=
B
.
contiguous
()
B
=
B
.
contiguous
()
if
C
.
stride
(
-
1
)
!=
1
:
if
C
.
stride
(
-
1
)
!=
1
:
...
@@ -89,6 +88,7 @@ def _mamba_chunk_scan_combined_fwd(x,
...
@@ -89,6 +88,7 @@ def _mamba_chunk_scan_combined_fwd(x,
dA_cumsum
,
dt
=
_chunk_cumsum_fwd
(
dt
,
dA_cumsum
,
dt
=
_chunk_cumsum_fwd
(
dt
,
A
,
A
,
chunk_size
,
chunk_size
,
cu_chunk_seqlens
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
dt_softplus
,
dt_softplus
=
dt_softplus
,
dt_limit
=
dt_limit
)
dt_limit
=
dt_limit
)
...
@@ -99,36 +99,31 @@ def _mamba_chunk_scan_combined_fwd(x,
...
@@ -99,36 +99,31 @@ def _mamba_chunk_scan_combined_fwd(x,
x
,
x
,
dt
,
dt
,
dA_cumsum
,
dA_cumsum
,
seq_idx
=
seq_idx
,
cu_chunk_seqlens
,
states_in_fp32
=
True
)
states_in_fp32
=
True
)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# - for handling chunked prefill, this requires i) initial_states
and
# ii) seq_idx
iii) is_cont_batched and (iv) chunk_offsets
to be all specified.
# ii) seq_idx to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states
=
_state_passing_fwd
(
states
=
_state_passing_fwd
(
rearrange
(
states
,
"... p n -> ... (p n)"
),
rearrange
(
states
,
"... p n -> ... (p n)"
),
dA_cumsum
,
# (nheads, nchunks, chunk_size)
dA_cumsum
,
# (nheads, nchunks, chunk_size)
cu_chunk_seqlens
,
initial_states
=
rearrange
(
initial_states
,
"... p n -> ... (p n)"
)
initial_states
=
rearrange
(
initial_states
,
"... p n -> ... (p n)"
)
if
initial_states
is
not
None
else
if
initial_states
is
not
None
else
None
,
# (batch, nheads, headdim*dstate)
None
,
# (batch, nheads, headdim*dstate)
seq_idx
=
seq_idx
,
seq_idx
=
seq_idx
,
out_dtype
=
state_dtype
if
state_dtype
is
not
None
else
C
.
dtype
,
out_dtype
=
state_dtype
if
state_dtype
is
not
None
else
C
.
dtype
)
chunk_offsets
=
chunk_offsets
)
states
=
rearrange
(
states
,
"... (p n) -> ... p n"
,
n
=
dstate
)
states
=
rearrange
(
states
,
"... (p n) -> ... p n"
,
n
=
dstate
)
# 4. Compute batched matrix multiply for C_j^T B_i terms
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB
=
_bmm_chunk_fwd
(
C
,
CB
=
_bmm_chunk_fwd
(
C
,
B
,
B
,
chunk_size
,
chunk_size
,
seq_idx
=
seq_idx
,
cu_chunk_seqlens
,
output_dtype
=
torch
.
float32
)
output_dtype
=
torch
.
float32
)
# 5. Scan and compute the diagonal blocks, taking into
# 5. Scan and compute the diagonal blocks, taking into
...
@@ -148,26 +143,15 @@ def _mamba_chunk_scan_combined_fwd(x,
...
@@ -148,26 +143,15 @@ def _mamba_chunk_scan_combined_fwd(x,
dA_cumsum
,
dA_cumsum
,
C
,
C
,
states
,
states
,
cu_chunk_seqlens
,
out
,
# in-place update
out
,
# in-place update
seq_idx
,
seq_idx
,
D
=
D
,
D
=
D
,
z
=
z
,
z
=
z
,
chunk_indices
=
chunk_indices
,
chunk_offsets
=
chunk_offsets
,
initial_states
=
initial_states
,
)
varlen_states
=
chunk_state_varlen
(
B
,
x
,
dt
,
dA_cumsum
,
cu_seqlens
,
states
,
initial_states
=
initial_states
,
initial_states
=
initial_states
,
)
)
return
varlen_stat
es
return
states
[
last_chunk_indic
es
]
def
mamba_chunk_scan_combined_varlen
(
def
mamba_chunk_scan_combined_varlen
(
...
@@ -178,14 +162,14 @@ def mamba_chunk_scan_combined_varlen(
...
@@ -178,14 +162,14 @@ def mamba_chunk_scan_combined_varlen(
C
,
C
,
chunk_size
,
chunk_size
,
cu_seqlens
,
cu_seqlens
,
cu_chunk_seqlens
,
last_chunk_indices
,
seq_idx
,
seq_idx
,
out
,
out
,
D
=
None
,
D
=
None
,
z
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_bias
=
None
,
initial_states
=
None
,
initial_states
=
None
,
chunk_indices
=
None
,
chunk_offsets
=
None
,
dt_softplus
=
False
,
dt_softplus
=
False
,
dt_limit
=
(
0.0
,
float
(
"inf"
)),
dt_limit
=
(
0.0
,
float
(
"inf"
)),
state_dtype
=
None
,
state_dtype
=
None
,
...
@@ -198,8 +182,10 @@ def mamba_chunk_scan_combined_varlen(
...
@@ -198,8 +182,10 @@ def mamba_chunk_scan_combined_varlen(
B: (seqlen, ngroups, dstate)
B: (seqlen, ngroups, dstate)
C: (seqlen, ngroups, dstate)
C: (seqlen, ngroups, dstate)
chunk_size: int
chunk_size: int
seq_idx: (seqlen)
cu_seqlens: (batch + 1,)
cu_seqlens: (batch + 1)
cu_chunk_seqlens: (nchunks + 1,)
last_chunk_indices: (batch,)
seq_idx: (nchunks,)
out: (seqlen, nheads, headdim) preallocated output tensor
out: (seqlen, nheads, headdim) preallocated output tensor
D: (nheads, headdim) or (nheads,)
D: (nheads, headdim) or (nheads,)
z: (seqlen, nheads, headdim)
z: (seqlen, nheads, headdim)
...
@@ -228,9 +214,9 @@ def mamba_chunk_scan_combined_varlen(
...
@@ -228,9 +214,9 @@ def mamba_chunk_scan_combined_varlen(
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
initial_states
=
initial_states
,
initial_states
=
initial_states
,
seq_idx
=
seq_idx
,
seq_idx
=
seq_idx
,
chunk_indices
=
chunk_indices
,
chunk_offsets
=
chunk_offsets
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
cu_chunk_seqlens
=
cu_chunk_seqlens
,
last_chunk_indices
=
last_chunk_indices
,
dt_softplus
=
dt_softplus
,
dt_softplus
=
dt_softplus
,
dt_limit
=
dt_limit
,
dt_limit
=
dt_limit
,
state_dtype
=
state_dtype
)
state_dtype
=
state_dtype
)
...
...
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
View file @
fea3e476
...
@@ -30,8 +30,7 @@ def _state_passing_fwd_kernel(
...
@@ -30,8 +30,7 @@ def _state_passing_fwd_kernel(
dA_cs_ptr
,
dA_cs_ptr
,
initstates_ptr
,
initstates_ptr
,
seq_idx_ptr
,
seq_idx_ptr
,
chunk_offsets_ptr
,
cu_chunk_seqlens_ptr
,
chunk_meta_num
,
# Matrix dimensions
# Matrix dimensions
dim
:
tl
.
constexpr
,
dim
:
tl
.
constexpr
,
nchunks
,
nchunks
,
...
@@ -50,94 +49,52 @@ def _state_passing_fwd_kernel(
...
@@ -50,94 +49,52 @@ def _state_passing_fwd_kernel(
stride_initstates_batch
:
tl
.
int64
,
stride_initstates_batch
:
tl
.
int64
,
stride_initstates_head
:
tl
.
int64
,
stride_initstates_head
:
tl
.
int64
,
stride_initstates_dim
:
tl
.
constexpr
,
stride_initstates_dim
:
tl
.
constexpr
,
stride_seq_idx_
seqlen
:
tl
.
constexpr
,
stride_seq_idx_
chunk
:
tl
.
constexpr
,
# Meta-parameters
# Meta-parameters
HAS_INITSTATES
:
tl
.
constexpr
,
HAS_INITSTATES
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
pid_h
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
1
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
states_ptr
+=
pid_h
*
stride_states_head
states_ptr
+=
pid_h
*
stride_states_head
dA_cs_ptr
+=
pid_h
*
stride_dA_cs_head
+
(
chunk_size
-
dA_cs_ptr
+=
pid_h
*
stride_dA_cs_head
+
(
chunk_size
-
1
)
*
stride_dA_cs_csize
1
)
*
stride_dA_cs_csize
out_ptr
+=
pid_h
*
stride_out_head
out_ptr
+=
pid_h
*
stride_out_head
if
HAS_INITSTATES
:
initstates_ptr
+=
pid_h
*
stride_initstates_head
offs_m
=
pid_m
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
offs_m
=
pid_m
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
states_ptrs
=
states_ptr
+
offs_m
*
stride_states_dim
states_ptrs
=
states_ptr
+
offs_m
*
stride_states_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
# - states will be the past state of the sequence that continues on the current chunk
if
HAS_INITSTATES
:
if
not
HAS_INITSTATES
:
initstates_ptrs
=
initstates_ptr
\
states
=
tl
.
zeros
((
BLOCK_SIZE
,
),
dtype
=
tl
.
float32
)
+
pid_h
*
stride_initstates_head
\
else
:
+
offs_m
*
stride_initstates_dim
initstates_ptr
+=
offs_m
*
stride_initstates_dim
initstates_ptrs
=
initstates_ptr
# - for cont batches, for the first chunk this means it will be the first batch's
# init state
states
=
tl
.
load
(
initstates_ptrs
,
mask
=
offs_m
<
dim
,
states
=
tl
.
load
(
initstates_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
else
:
states
=
tl
.
zeros
((
BLOCK_SIZE
,
),
dtype
=
tl
.
float32
)
tl
.
store
(
out_ptrs
,
states
,
mask
=
offs_m
<
dim
)
prev_seq_idx
=
0
out_ptrs
+=
stride_out_chunk
for
c
in
range
(
nchunks
):
prev_seq_idx_chunk_end
=
0
logical_chunk_idx
=
0
for
c
in
range
(
nchunks
-
1
):
new_states
=
tl
.
load
(
states_ptrs
,
mask
=
offs_m
<
dim
,
new_states
=
tl
.
load
(
states_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
dA_cs
=
tl
.
load
(
dA_cs_ptr
).
to
(
tl
.
float32
)
dA_cs
=
tl
.
load
(
dA_cs_ptr
).
to
(
tl
.
float32
)
scale_mask
=
True
seq_idx
=
tl
.
load
(
seq_idx_ptr
+
c
*
stride_seq_idx_chunk
)
# - the seq to pass forward is the one that is flushed to the right
# we have started a new sequence
# boundary.
if
prev_seq_idx
!=
seq_idx
:
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end
=
tl
.
load
(
seq_idx_ptr
+
(
min
((
c
+
1
)
*
chunk_size
,
seqlen
)
-
1
)
*
stride_seq_idx_seqlen
)
if
HAS_INITSTATES
:
if
HAS_INITSTATES
:
if
prev_seq_idx_chunk_end
!=
seq_idx_chunk_end
:
initstates_ptrs
=
initstates_ptr
+
seq_idx
*
stride_initstates_batch
\
# this means in the current chunk the rightmost flushed seq
+
pid_h
*
stride_initstates_head
\
# has changed.
+
offs_m
*
stride_initstates_dim
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs
=
initstates_ptr
+
seq_idx_chunk_end
*
stride_initstates_batch
# - update state with seq_idx_new's init state
states
=
tl
.
load
(
initstates_ptrs
,
mask
=
offs_m
<
dim
,
states
=
tl
.
load
(
initstates_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
other
=
0.0
).
to
(
tl
.
float32
)
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start
=
tl
.
load
(
seq_idx_ptr
+
min
(
c
*
chunk_size
,
seqlen
)
*
stride_seq_idx_seqlen
)
logical_chunk_idx
+=
seq_idx_chunk_end
-
seq_idx_chunk_start
# - load the chunk offset:
c_off
=
tl
.
load
(
chunk_offsets_ptr
+
logical_chunk_idx
,
mask
=
logical_chunk_idx
<
chunk_meta_num
,
other
=
0
)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if
c_off
>
0
:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary
=
tl
.
load
(
dA_cs_ptr
-
(
chunk_size
-
1
)
*
stride_dA_cs_csize
+
(
c_off
-
1
)
*
stride_dA_cs_csize
,
mask
=
(
c_off
-
1
)
>
-
1
and
c_off
<
chunk_size
,
other
=
0.0
)
dA_cs
-=
dA_cs_boundary
# - increment logical chunk index for every physical chunk
logical_chunk_idx
+=
1
else
:
else
:
scale_mask
=
seq_idx_chunk_end
==
prev_seq_idx_chunk_end
states
=
tl
.
zeros
((
BLOCK_SIZE
,
),
dtype
=
tl
.
float32
)
prev_seq_idx_chunk_end
=
seq_idx_chunk_end
scale
=
tl
.
where
(
scale_mask
,
tl
.
exp
(
dA_cs
),
0.0
)
prev_seq_idx
=
seq_idx
states
=
scale
*
states
+
new_states
states
=
tl
.
exp
(
dA_cs
)
*
states
+
new_states
tl
.
store
(
out_ptrs
,
states
,
mask
=
offs_m
<
dim
)
tl
.
store
(
out_ptrs
,
states
,
mask
=
offs_m
<
dim
)
states_ptrs
+=
stride_states_chunk
states_ptrs
+=
stride_states_chunk
...
@@ -148,8 +105,8 @@ def _state_passing_fwd_kernel(
...
@@ -148,8 +105,8 @@ def _state_passing_fwd_kernel(
def
_state_passing_fwd
(
def
_state_passing_fwd
(
states
,
states
,
dA_cumsum
,
dA_cumsum
,
cu_chunk_seqlens
,
seq_idx
,
seq_idx
,
chunk_offsets
,
initial_states
=
None
,
initial_states
=
None
,
out_dtype
=
None
,
out_dtype
=
None
,
):
):
...
@@ -175,9 +132,7 @@ def _state_passing_fwd(
...
@@ -175,9 +132,7 @@ def _state_passing_fwd(
dA_cs_ptr
=
dA_cumsum
,
dA_cs_ptr
=
dA_cumsum
,
initstates_ptr
=
initial_states
,
initstates_ptr
=
initial_states
,
seq_idx_ptr
=
seq_idx
,
seq_idx_ptr
=
seq_idx
,
chunk_offsets_ptr
=
chunk_offsets
,
cu_chunk_seqlens_ptr
=
cu_chunk_seqlens
,
chunk_meta_num
=
len
(
chunk_offsets
)
if
chunk_offsets
is
not
None
else
0
,
dim
=
dim
,
dim
=
dim
,
nchunks
=
nchunks
,
nchunks
=
nchunks
,
seqlen
=
seqlen
if
seq_idx
is
not
None
else
0
,
seqlen
=
seqlen
if
seq_idx
is
not
None
else
0
,
...
@@ -194,7 +149,7 @@ def _state_passing_fwd(
...
@@ -194,7 +149,7 @@ def _state_passing_fwd(
stride_initstates_batch
=
initial_states_strides
[
0
],
stride_initstates_batch
=
initial_states_strides
[
0
],
stride_initstates_head
=
initial_states_strides
[
1
],
stride_initstates_head
=
initial_states_strides
[
1
],
stride_initstates_dim
=
initial_states_strides
[
2
],
stride_initstates_dim
=
initial_states_strides
[
2
],
stride_seq_idx_
seqlen
=
seq_idx
.
stride
(
0
),
stride_seq_idx_
chunk
=
seq_idx
.
stride
(
0
),
HAS_INITSTATES
=
initial_states
is
not
None
,
HAS_INITSTATES
=
initial_states
is
not
None
,
)
)
return
out
return
out
vllm/model_executor/models/plamo2.py
View file @
fea3e476
...
@@ -260,9 +260,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
...
@@ -260,9 +260,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
prep_initial_states
=
attn_metadata
.
prep_initial_states
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx_p
seq_idx_p
=
attn_metadata
.
seq_idx_p
chunk_indices_p
=
attn_metadata
.
chunk_indices_p
chunk_offsets_p
=
attn_metadata
.
chunk_offsets_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
cu_chunk_seqlen_p
=
attn_metadata
.
cu_chunk_seqlen_p
last_chunk_indices_p
=
attn_metadata
.
last_chunk_indices_p
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)
projected_states
=
self
.
in_proj
(
hidden_states
)
...
@@ -368,9 +368,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
...
@@ -368,9 +368,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self
.
num_heads
//
self
.
tp_size
,
self
.
head_dim
),
self
.
num_heads
//
self
.
tp_size
,
self
.
head_dim
),
dt_bias
=
self
.
dt_bias
,
dt_bias
=
self
.
dt_bias
,
seq_idx
=
seq_idx_p
,
seq_idx
=
seq_idx_p
,
chunk_indices
=
chunk_indices_p
,
chunk_offsets
=
chunk_offsets_p
,
cu_seqlens
=
query_start_loc_p
,
cu_seqlens
=
query_start_loc_p
,
cu_chunk_seqlens
=
cu_chunk_seqlen_p
,
last_chunk_indices
=
last_chunk_indices_p
,
initial_states
=
initial_states
,
initial_states
=
initial_states
,
dt_softplus
=
True
,
dt_softplus
=
True
,
dt_limit
=
(
0.0
,
float
(
"inf"
)),
dt_limit
=
(
0.0
,
float
(
"inf"
)),
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
fea3e476
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
...
@@ -8,6 +7,7 @@ import torch
...
@@ -8,6 +7,7 @@ import torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.mamba_attn
import
(
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadataBuilder
)
BaseMambaAttentionMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
...
@@ -17,91 +17,6 @@ from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
...
@@ -17,91 +17,6 @@ from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
def
_query_start_loc_to_chunk_indices_offsets
(
query_start_loc
:
torch
.
Tensor
,
chunk_size
:
int
,
total_seqlens
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
lengths, shape (num_seqs + 1,).
The first element should be 0. Each entry represents the starting
index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk
(number of tokens per chunk).
total_seqlens (int): The total number of tokens in the batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices
indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets
indicating the starting index of each logical chunk within
its physical chunk.
This function computes the chunk indices and offsets for the given
query_start_loc and chunk_size. Both are tensors of integers with length N,
where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same
sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence
boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states
(see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each
logical chunk.
The chunk_offsets tensor contains the offset (AKA starting index) of the
logical chunk in the physical chunk.
Example:
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens. The physical
chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk
and contains the first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""
cu_seqlens
=
query_start_loc
[
1
:]
# remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N
=
math
.
ceil
(
total_seqlens
/
chunk_size
)
+
(
cu_seqlens
[:
-
1
]
%
chunk_size
>
0
).
sum
()
chunk_indices
=
torch
.
arange
(
N
,
dtype
=
torch
.
int
,
device
=
query_start_loc
.
device
)
chunk_offsets
=
torch
.
zeros
((
N
,
),
dtype
=
torch
.
int
,
device
=
query_start_loc
.
device
)
p
=
0
# num of insertions
for
s
,
e
in
zip
(
cu_seqlens
[:
-
1
],
cu_seqlens
[
1
:]):
# if does not divide chunk_size, then there is one chunk insertion
p
+=
(
s
%
chunk_size
>
0
)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s
,
_e
=
s
//
chunk_size
+
p
,
e
//
chunk_size
+
p
+
(
e
%
chunk_size
>
0
)
# adjust indices and offsets
chunk_indices
[
_s
:
_e
]
-=
p
chunk_offsets
[
_s
]
=
s
%
chunk_size
return
chunk_indices
,
chunk_offsets
class
Mamba2AttentionBackend
(
AttentionBackend
):
class
Mamba2AttentionBackend
(
AttentionBackend
):
@
staticmethod
@
staticmethod
...
@@ -125,8 +40,16 @@ class Mamba2AttentionMetadata:
...
@@ -125,8 +40,16 @@ class Mamba2AttentionMetadata:
# the batch has no prefill request.
# the batch has no prefill request.
has_initial_states_p
:
Optional
[
torch
.
Tensor
]
has_initial_states_p
:
Optional
[
torch
.
Tensor
]
seq_idx_p
:
Optional
[
torch
.
Tensor
]
seq_idx_p
:
Optional
[
torch
.
Tensor
]
chunk_indices_p
:
Optional
[
torch
.
Tensor
]
chunk_offsets_p
:
Optional
[
torch
.
Tensor
]
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p
:
Optional
[
torch
.
Tensor
]
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p
:
Optional
[
torch
.
Tensor
]
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
...
@@ -151,13 +74,14 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -151,13 +74,14 @@ class Mamba2AttentionMetadataBuilder(
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
)
->
Mamba2AttentionMetadata
:
fast_build
:
bool
=
False
)
->
Mamba2AttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_reqs
=
common_attn_metadata
.
num_reqs
query_start_loc_p
=
None
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
query_start_loc_p
=
None
seq_idx_p
=
None
seq_idx_p
=
None
chunk_indices_p
,
chunk_offsets_p
=
None
,
None
cu_chunk_seqlen_p
=
None
last_chunk_indices_p
=
None
# Need flags to indicate if there are initial states
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states_p
=
None
has_initial_states_p
=
None
prep_initial_states
=
False
prep_initial_states
=
False
...
@@ -171,7 +95,7 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -171,7 +95,7 @@ class Mamba2AttentionMetadataBuilder(
common_attn_metadata
,
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
))
decode_threshold
=
self
.
reorder_batch_threshold
))
# Compute seq_idx
, chunk_indices and chunk_offsets
for prefill only
# Compute seq_idx for prefill only
if
num_prefills
>
0
:
if
num_prefills
>
0
:
#[batch,]
#[batch,]
has_initial_states_cpu
=
(
has_initial_states_cpu
=
(
...
@@ -184,21 +108,68 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -184,21 +108,68 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p
=
common_attn_metadata
.
query_start_loc
[
query_start_loc_p
=
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
-
num_prefills
-
1
:]
-
num_decode_tokens
seq_idx_p
=
torch
.
repeat_interleave
(
torch
.
arange
(
num_computed_tokens_p
=
\
num_prefills
,
common_attn_metadata
.
num_computed_tokens_cpu
[
dtype
=
torch
.
int32
,
num_reqs
-
num_prefills
:
num_reqs
]
device
=
query_start_loc_p
.
device
),
query_start_loc_p_cpu
=
common_attn_metadata
.
query_start_loc_cpu
[
query_start_loc_p
.
diff
(),
-
num_prefills
-
1
:]
-
num_decode_tokens
output_size
=
num_prefill_tokens
)
# The code below carefully constructs the chunks such that:
# We compute metadata for chunked prefill once at the top level
# 1. Chunks contain tokens from a *single* sequence only.
# model forward and reuse them in mamba layers. If not needed,
# 2. For every sequence, we are guaranteed that we can
# they will be ignored inside mamba kernels.
# retrieve the mamba state *every* chunk_size tokens.
if
prep_initial_states
:
# Constraint (1) dramatically simplifies the mamba2 kernels.
chunk_indices_p
,
chunk_offsets_p
=
(
# Constraint (2) dramatically simplifies the implementation
_query_start_loc_to_chunk_indices_offsets
(
# of prefix caching for mamba2 (wip). We need to take care
query_start_loc_p
,
self
.
chunk_size
,
# of the interaction with chunked prefill in order to
num_prefill_tokens
))
# satisfy constraint (2).
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen
=
[]
seq_idx
=
[]
last_chunk_indices
=
[]
seqlen_pos
=
0
for
req_idx
in
range
(
num_prefills
):
this_num_computed
=
num_computed_tokens_p
[
req_idx
].
item
()
this_new_tokens
=
query_start_loc_p_cpu
[
req_idx
+
1
].
item
(
)
-
query_start_loc_p_cpu
[
req_idx
].
item
()
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if
this_num_computed
%
self
.
chunk_size
!=
0
:
seq_idx
.
append
(
req_idx
)
cu_chunk_seqlen
.
append
(
seqlen_pos
)
# how many tokens to finish the chunk?
chunk_len
=
cdiv
(
this_num_computed
,
self
.
chunk_size
)
*
self
.
chunk_size
-
this_num_computed
# we can only use at most this_new_tokens
chunk_len
=
min
(
chunk_len
,
this_new_tokens
)
seqlen_pos
+=
chunk_len
this_new_tokens
-=
chunk_len
n_chunks
=
cdiv
(
this_new_tokens
,
self
.
chunk_size
)
for
chunk
in
range
(
n_chunks
):
seq_idx
.
append
(
req_idx
)
cu_chunk_seqlen
.
append
(
seqlen_pos
)
chunk_len
=
min
(
self
.
chunk_size
,
this_new_tokens
)
seqlen_pos
+=
chunk_len
this_new_tokens
-=
chunk_len
assert
this_new_tokens
==
0
last_chunk_indices
.
append
(
len
(
cu_chunk_seqlen
)
-
1
)
cu_chunk_seqlen
.
append
(
seqlen_pos
)
seq_idx_p
=
torch
.
as_tensor
(
seq_idx
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
cu_chunk_seqlen_p
=
torch
.
as_tensor
(
cu_chunk_seqlen
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
last_chunk_indices_p
=
torch
.
as_tensor
(
last_chunk_indices
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
compute_causal_conv1d_metadata
(
query_start_loc_p
)
compute_causal_conv1d_metadata
(
query_start_loc_p
)
...
@@ -222,9 +193,9 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -222,9 +193,9 @@ class Mamba2AttentionMetadataBuilder(
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
has_initial_states_p
=
has_initial_states_p
,
has_initial_states_p
=
has_initial_states_p
,
seq_idx_p
=
seq_idx_p
,
seq_idx_p
=
seq_idx_p
,
chunk_indices_p
=
chunk_indices_p
,
chunk_offsets_p
=
chunk_offsets_p
,
state_indices_tensor
=
state_indices_tensor
,
state_indices_tensor
=
state_indices_tensor
,
cu_chunk_seqlen_p
=
cu_chunk_seqlen_p
,
last_chunk_indices_p
=
last_chunk_indices_p
,
nums_dict
=
nums_dict
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment