Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ae0f69b1
Unverified
Commit
ae0f69b1
authored
Dec 08, 2025
by
roikoren755
Committed by
GitHub
Dec 08, 2025
Browse files
Add SpecDec support to `selective_state_update` (#29488)
Signed-off-by:
Roi Koren
<
roik@nvidia.com
>
parent
799804d1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
507 additions
and
74 deletions
+507
-74
tests/kernels/mamba/test_mamba_ssm.py
tests/kernels/mamba/test_mamba_ssm.py
+325
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+182
-74
No files found.
tests/kernels/mamba/test_mamba_ssm.py
View file @
ae0f69b1
...
@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
...
@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
1
,
2
,
4
])
def
test_selective_state_update_varlen
(
dim
,
dstate
,
has_z
,
itype
,
max_seq_len
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
5e-2
,
1.5e-1
if
torch
.
version
.
hip
:
atol
*=
2
# set seed
current_platform
.
seed_everything
(
0
)
batch_size
=
4
token_counts
=
torch
.
randint
(
1
,
max_seq_len
+
1
,
(
batch_size
,),
device
=
device
)
total_tokens
=
int
(
token_counts
.
sum
().
item
())
cu_seqlens
=
torch
.
tensor
(
[
0
]
+
torch
.
cumsum
(
token_counts
,
dim
=
0
).
tolist
(),
dtype
=
torch
.
int32
,
device
=
device
,
)
state
=
torch
.
randn
(
batch_size
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
x
=
torch
.
randn
(
total_tokens
,
dim
,
device
=
device
,
dtype
=
itype
)
out
=
torch
.
empty_like
(
x
)
dt
=
torch
.
randn
(
total_tokens
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
total_tokens
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
total_tokens
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref
=
state
.
detach
().
clone
()
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
out
=
out
,
cu_seqlens
=
cu_seqlens
,
)
out_ref_list
=
[]
for
seq_idx
in
range
(
batch_size
):
start_idx
=
cu_seqlens
[
seq_idx
].
item
()
end_idx
=
cu_seqlens
[
seq_idx
+
1
].
item
()
num_tokens
=
end_idx
-
start_idx
for
token_idx
in
range
(
num_tokens
):
idx
=
start_idx
+
token_idx
out_ref_list
.
append
(
selective_state_update_ref
(
state_ref
[
seq_idx
:
seq_idx
+
1
],
x
[
idx
:
idx
+
1
],
dt
[
idx
:
idx
+
1
],
A
,
B
[
idx
:
idx
+
1
],
C
[
idx
:
idx
+
1
],
D
=
D
,
z
=
z
[
idx
:
idx
+
1
]
if
has_z
else
None
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
)
)
out_ref
=
torch
.
cat
(
out_ref_list
,
dim
=
0
)
assert
torch
.
allclose
(
state
,
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"wtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"wtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
256
,
1024
,
4096
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
256
,
1024
,
4096
])
...
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
...
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
2
,
4
])
def
test_selective_state_update_with_num_accepted_tokens
(
dim
,
dstate
,
has_z
,
itype
,
max_seq_len
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
5e-2
,
1.5e-1
if
torch
.
version
.
hip
:
atol
*=
2
current_platform
.
seed_everything
(
0
)
batch_size
=
4
tokens_per_seq
=
torch
.
randint
(
1
,
max_seq_len
+
1
,
(
batch_size
,),
device
=
device
)
total_tokens
=
int
(
tokens_per_seq
.
sum
().
item
())
num_accepted_tokens
=
torch
.
randint
(
0
,
max_seq_len
,
(
batch_size
,),
device
=
device
)
num_accepted_tokens
[
0
]
=
0
# Add edge-case of no accepted tokens
num_accepted_tokens
[
1
]
=
max_seq_len
# Add edge-case of all tokens accepted
cu_seqlens
=
torch
.
tensor
(
[
0
]
+
torch
.
cumsum
(
tokens_per_seq
,
dim
=
0
).
tolist
(),
dtype
=
torch
.
int32
,
device
=
device
,
)
total_state_slots
=
50
state
=
torch
.
randn
(
total_state_slots
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state_batch_indices
=
torch
.
full
(
(
batch_size
,
max_seq_len
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
device
)
initial_state_slots
=
torch
.
randint
(
0
,
15
,
(
batch_size
,),
device
=
device
,
dtype
=
torch
.
int32
)
for
seq_idx
in
range
(
batch_size
):
token_pos
=
max
(
num_accepted_tokens
[
seq_idx
].
item
()
-
1
,
0
)
state_batch_indices
[
seq_idx
,
token_pos
]
=
initial_state_slots
[
seq_idx
]
dst_state_batch_indices
=
torch
.
full
(
(
batch_size
,
max_seq_len
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
device
)
slot_offset
=
15
dst_slots_map
=
{}
for
seq_idx
in
range
(
batch_size
):
for
token_idx
in
range
(
tokens_per_seq
[
seq_idx
].
item
()):
dst_state_batch_indices
[
seq_idx
,
token_idx
]
=
slot_offset
dst_slots_map
[(
seq_idx
,
token_idx
)]
=
slot_offset
slot_offset
+=
1
x
=
torch
.
randn
(
total_tokens
,
dim
,
device
=
device
,
dtype
=
itype
)
out
=
torch
.
empty_like
(
x
)
dt
=
torch
.
randn
(
total_tokens
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
total_tokens
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
total_tokens
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref_intermediate
=
{}
out_ref_list
=
[]
for
seq_idx
in
range
(
batch_size
):
seq_start
=
cu_seqlens
[
seq_idx
].
item
()
seq_end
=
cu_seqlens
[
seq_idx
+
1
].
item
()
num_tokens
=
seq_end
-
seq_start
token_pos
=
max
(
num_accepted_tokens
[
seq_idx
].
item
()
-
1
,
0
)
initial_slot
=
state_batch_indices
[
seq_idx
,
token_pos
].
item
()
state_seq
=
state
[
initial_slot
:
initial_slot
+
1
].
clone
()
for
token_idx
in
range
(
num_tokens
):
global_idx
=
seq_start
+
token_idx
out_token
=
selective_state_update_ref
(
state_seq
,
x
[
global_idx
:
global_idx
+
1
],
dt
[
global_idx
:
global_idx
+
1
],
A
,
B
[
global_idx
:
global_idx
+
1
],
C
[
global_idx
:
global_idx
+
1
],
D
=
D
,
z
=
z
[
global_idx
:
global_idx
+
1
]
if
has_z
else
None
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
)
out_ref_list
.
append
(
out_token
)
state_ref_intermediate
[(
seq_idx
,
token_idx
)]
=
state_seq
.
clone
()
out_ref
=
torch
.
cat
(
out_ref_list
,
dim
=
0
)
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
out
=
out
,
cu_seqlens
=
cu_seqlens
,
state_batch_indices
=
state_batch_indices
,
dst_state_batch_indices
=
dst_state_batch_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
pad_slot_id
=
PAD_SLOT_ID
,
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
for
seq_idx
in
range
(
batch_size
):
num_tokens
=
tokens_per_seq
[
seq_idx
].
item
()
for
token_idx
in
range
(
num_tokens
):
dst_slot
=
dst_slots_map
[(
seq_idx
,
token_idx
)]
state_ref
=
state_ref_intermediate
[(
seq_idx
,
token_idx
)].
squeeze
(
0
)
assert
torch
.
allclose
(
state
[
dst_slot
],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
2
,
4
])
def
test_selective_state_update_varlen_with_num_accepted
(
dim
,
dstate
,
has_z
,
itype
,
max_seq_len
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
5e-2
,
1.5e-1
if
torch
.
version
.
hip
:
atol
*=
2
current_platform
.
seed_everything
(
0
)
batch_size
=
4
tokens_per_seq
=
torch
.
randint
(
1
,
max_seq_len
+
1
,
(
batch_size
,),
device
=
device
)
total_tokens
=
int
(
tokens_per_seq
.
sum
().
item
())
num_accepted_tokens
=
torch
.
randint
(
0
,
max_seq_len
,
(
batch_size
,),
device
=
device
)
num_accepted_tokens
[
0
]
=
0
# Add edge-case of no accepted tokens
num_accepted_tokens
[
1
]
=
max_seq_len
# Add edge-case of all tokens accepted
cu_seqlens
=
torch
.
tensor
(
[
0
]
+
torch
.
cumsum
(
tokens_per_seq
,
dim
=
0
).
tolist
(),
dtype
=
torch
.
int32
,
device
=
device
,
)
total_state_slots
=
50
state
=
torch
.
randn
(
total_state_slots
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state_batch_indices
=
torch
.
full
(
(
batch_size
,
max_seq_len
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
device
)
initial_state_slots
=
torch
.
randint
(
0
,
15
,
(
batch_size
,),
device
=
device
,
dtype
=
torch
.
int32
)
for
seq_idx
in
range
(
batch_size
):
token_pos
=
max
(
num_accepted_tokens
[
seq_idx
].
item
()
-
1
,
0
)
state_batch_indices
[
seq_idx
,
token_pos
]
=
initial_state_slots
[
seq_idx
]
dst_state_batch_indices
=
torch
.
full
(
(
batch_size
,
max_seq_len
),
PAD_SLOT_ID
,
dtype
=
torch
.
int32
,
device
=
device
)
slot_offset
=
15
dst_slots_map
=
{}
for
seq_idx
in
range
(
batch_size
):
for
token_idx
in
range
(
tokens_per_seq
[
seq_idx
].
item
()):
dst_state_batch_indices
[
seq_idx
,
token_idx
]
=
slot_offset
dst_slots_map
[(
seq_idx
,
token_idx
)]
=
slot_offset
slot_offset
+=
1
x
=
torch
.
randn
(
total_tokens
,
dim
,
device
=
device
,
dtype
=
itype
)
out
=
torch
.
empty_like
(
x
)
dt
=
torch
.
randn
(
total_tokens
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
total_tokens
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
total_tokens
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref_intermediate
=
{}
for
seq_idx
in
range
(
batch_size
):
seq_start
=
cu_seqlens
[
seq_idx
].
item
()
seq_end
=
cu_seqlens
[
seq_idx
+
1
].
item
()
num_tokens
=
seq_end
-
seq_start
token_pos
=
max
(
num_accepted_tokens
[
seq_idx
].
item
()
-
1
,
0
)
initial_slot
=
state_batch_indices
[
seq_idx
,
token_pos
].
item
()
state_seq
=
state
[
initial_slot
:
initial_slot
+
1
].
clone
()
for
token_idx
in
range
(
num_tokens
):
global_idx
=
seq_start
+
token_idx
selective_state_update_ref
(
state_seq
,
x
[
global_idx
:
global_idx
+
1
],
dt
[
global_idx
:
global_idx
+
1
],
A
,
B
[
global_idx
:
global_idx
+
1
],
C
[
global_idx
:
global_idx
+
1
],
D
=
D
,
z
=
z
[
global_idx
:
global_idx
+
1
]
if
has_z
else
None
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
)
state_ref_intermediate
[(
seq_idx
,
token_idx
)]
=
state_seq
.
clone
()
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
out
=
out
,
cu_seqlens
=
cu_seqlens
,
state_batch_indices
=
state_batch_indices
,
dst_state_batch_indices
=
dst_state_batch_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
pad_slot_id
=
PAD_SLOT_ID
,
)
for
seq_idx
in
range
(
batch_size
):
num_tokens
=
tokens_per_seq
[
seq_idx
].
item
()
for
token_idx
in
range
(
num_tokens
):
dst_slot
=
dst_slots_map
[(
seq_idx
,
token_idx
)]
state_ref
=
state_ref_intermediate
[(
seq_idx
,
token_idx
)].
squeeze
(
0
)
assert
torch
.
allclose
(
state
[
dst_slot
],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
ae0f69b1
...
@@ -36,10 +36,14 @@ else:
...
@@ -36,10 +36,14 @@ else:
is
not
None
is
not
None
}
}
)
)
@
triton
.
heuristics
(
{
"IS_SPEC_DECODING"
:
lambda
args
:
args
[
"num_accepted_tokens_ptr"
]
is
not
None
}
)
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens_ptr"
]
is
not
None
})
@
triton
.
heuristics
(
@
triton
.
heuristics
(
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])}
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])}
)
)
@
triton
.
jit
@
triton
.
jit
(
do_not_specialize
=
[
"N"
])
def
_selective_scan_update_kernel
(
def
_selective_scan_update_kernel
(
# Pointers to matrices
# Pointers to matrices
state_ptr
,
state_ptr
,
...
@@ -55,8 +59,10 @@ def _selective_scan_update_kernel(
...
@@ -55,8 +59,10 @@ def _selective_scan_update_kernel(
state_batch_indices_ptr
,
state_batch_indices_ptr
,
dst_state_batch_indices_ptr
,
dst_state_batch_indices_ptr
,
pad_slot_id
,
pad_slot_id
,
num_accepted_tokens_ptr
,
cu_seqlens_ptr
,
# Matrix dimensions
# Matrix dimensions
batch
,
N
,
nheads
,
nheads
,
dim
,
dim
,
dstate
,
dstate
,
...
@@ -91,6 +97,10 @@ def _selective_scan_update_kernel(
...
@@ -91,6 +97,10 @@ def _selective_scan_update_kernel(
stride_out_batch
,
stride_out_batch
,
stride_out_head
,
stride_out_head
,
stride_out_dim
,
stride_out_dim
,
stride_state_indices_batch
,
stride_state_indices_T
,
stride_dst_state_indices_batch
,
stride_dst_state_indices_T
,
# Meta-parameters
# Meta-parameters
DT_SOFTPLUS
:
tl
.
constexpr
,
DT_SOFTPLUS
:
tl
.
constexpr
,
TIE_HDIM
:
tl
.
constexpr
,
TIE_HDIM
:
tl
.
constexpr
,
...
@@ -99,22 +109,50 @@ def _selective_scan_update_kernel(
...
@@ -99,22 +109,50 @@ def _selective_scan_update_kernel(
HAS_D
:
tl
.
constexpr
,
HAS_D
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
HAS_STATE_BATCH_INDICES
:
tl
.
constexpr
,
HAS_STATE_BATCH_INDICES
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
):
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
if
IS_VARLEN
:
bos
=
tl
.
load
(
cu_seqlens_ptr
+
pid_b
).
to
(
tl
.
int64
)
eos
=
tl
.
load
(
cu_seqlens_ptr
+
pid_b
+
1
).
to
(
tl
.
int64
)
seq_len
=
eos
-
bos
if
seq_len
==
0
:
return
else
:
bos
=
pid_b
seq_len
=
1
state_ptr_base
=
state_ptr
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
# is the same as the batch id.
if
HAS_STATE_BATCH_INDICES
:
if
HAS_STATE_BATCH_INDICES
:
dst_state_batch_indices_ptr
+=
pid_b
if
IS_SPEC_DECODING
:
dst_state_batch_idx
=
tl
.
load
(
dst_state_batch_indices_ptr
).
to
(
tl
.
int64
)
num_accepted
=
tl
.
load
(
num_accepted_tokens_ptr
+
pid_b
).
to
(
tl
.
int64
)
init_token_idx
=
tl
.
maximum
(
num_accepted
-
1
,
0
)
else
:
init_token_idx
=
0
dst_state_batch_indices_ptr
+=
pid_b
*
stride_dst_state_indices_batch
if
not
IS_SPEC_DECODING
:
dst_state_batch_idx
=
tl
.
load
(
dst_state_batch_indices_ptr
+
init_token_idx
*
stride_dst_state_indices_T
).
to
(
tl
.
int64
)
dst_state_ptr
=
state_ptr
+
(
dst_state_ptr
=
state_ptr
+
(
dst_state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
dst_state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
)
)
state_batch_indices_ptr
+=
pid_b
state_batch_indices_ptr
+=
(
pid_b
*
stride_state_indices_batch
+
init_token_idx
*
stride_state_indices_T
)
state_batch_idx
=
tl
.
load
(
state_batch_indices_ptr
).
to
(
tl
.
int64
)
state_batch_idx
=
tl
.
load
(
state_batch_indices_ptr
).
to
(
tl
.
int64
)
state_ptr
+=
state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
state_ptr
+=
state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
else
:
else
:
...
@@ -123,45 +161,47 @@ def _selective_scan_update_kernel(
...
@@ -123,45 +161,47 @@ def _selective_scan_update_kernel(
)
)
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
x_ptr
+=
pid_b
*
stride_x_batch
+
pid_h
*
stride_x_head
x_ptr
+=
bos
*
stride_x_batch
+
pid_h
*
stride_x_head
dt_ptr
+=
pid_b
*
stride_dt_batch
+
pid_h
*
stride_dt_head
dt_ptr
+=
bos
*
stride_dt_batch
+
pid_h
*
stride_dt_head
if
HAS_DT_BIAS
:
if
HAS_DT_BIAS
:
dt_bias_ptr
+=
pid_h
*
stride_dt_bias_head
dt_bias_ptr
+=
pid_h
*
stride_dt_bias_head
A_ptr
+=
pid_h
*
stride_A_head
A_ptr
+=
pid_h
*
stride_A_head
B_ptr
+=
pid_b
*
stride_B_batch
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_B_group
B_ptr
+=
bos
*
stride_B_batch
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_B_group
C_ptr
+=
pid_b
*
stride_C_batch
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_C_group
C_ptr
+=
bos
*
stride_C_batch
+
(
pid_h
//
nheads_ngroups_ratio
)
*
stride_C_group
if
HAS_Z
:
if
HAS_Z
:
z_ptr
+=
pid_b
*
stride_z_batch
+
pid_h
*
stride_z_head
z_ptr
+=
bos
*
stride_z_batch
+
pid_h
*
stride_z_head
out_ptr
+=
pid_b
*
stride_out_batch
+
pid_h
*
stride_out_head
out_ptr
+=
bos
*
stride_out_batch
+
pid_h
*
stride_out_head
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_m
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_SIZE_DSTATE
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_SIZE_DSTATE
)
state_ptrs
=
state_ptr
+
(
state_ptrs
=
state_ptr
+
(
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
)
)
if
not
IS_SPEC_DECODING
:
dst_state_ptrs
=
dst_state_ptr
+
(
dst_state_ptrs
=
dst_state_ptr
+
(
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
)
)
x_ptrs
=
x_ptr
+
offs_m
*
stride_x_dim
dt_ptrs
=
dt_ptr
+
offs_m
*
stride_dt_dim
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
state_batch_idx
!=
pad_slot_id
state
=
tl
.
load
(
state_ptrs
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
HAS_DT_BIAS
:
if
HAS_DT_BIAS
:
dt_bias_ptrs
=
dt_bias_ptr
+
offs_m
*
stride_dt_bias_dim
dt_bias_ptrs
=
dt_bias_ptr
+
offs_m
*
stride_dt_bias_dim
if
HAS_D
:
if
HAS_D
:
D_ptr
+=
pid_h
*
stride_D_head
D_ptr
+=
pid_h
*
stride_D_head
A_ptrs
=
A_ptr
+
(
D_ptrs
=
D_ptr
+
offs_m
*
stride_D_dim
offs_m
[:,
None
]
*
stride_A_dim
+
offs_n
[
None
,
:]
*
stride_A_dstate
A_ptrs
=
A_ptr
+
offs_m
[:,
None
]
*
stride_A_dim
+
offs_n
[
None
,
:]
*
stride_A_dstate
)
for
i_t
in
range
(
seq_len
):
x_ptrs
=
x_ptr
+
offs_m
*
stride_x_dim
dt_ptrs
=
dt_ptr
+
offs_m
*
stride_dt_dim
B_ptrs
=
B_ptr
+
offs_n
*
stride_B_dstate
B_ptrs
=
B_ptr
+
offs_n
*
stride_B_dstate
C_ptrs
=
C_ptr
+
offs_n
*
stride_C_dstate
C_ptrs
=
C_ptr
+
offs_n
*
stride_C_dstate
if
HAS_D
:
D_ptrs
=
D_ptr
+
offs_m
*
stride_D_dim
if
HAS_Z
:
if
HAS_Z
:
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
state_batch_idx
!=
pad_slot_id
state
=
tl
.
load
(
state_ptrs
,
mask
=
mask
,
other
=
0.0
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
not
TIE_HDIM
:
if
not
TIE_HDIM
:
...
@@ -171,7 +211,9 @@ def _selective_scan_update_kernel(
...
@@ -171,7 +211,9 @@ def _selective_scan_update_kernel(
if
DT_SOFTPLUS
:
if
DT_SOFTPLUS
:
dt
=
softplus
(
dt
)
dt
=
softplus
(
dt
)
A
=
tl
.
load
(
A
=
tl
.
load
(
A_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
A_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
,
).
to
(
tl
.
float32
)
).
to
(
tl
.
float32
)
dA
=
tl
.
exp
(
A
*
dt
[:,
None
])
dA
=
tl
.
exp
(
A
*
dt
[:,
None
])
else
:
else
:
...
@@ -193,10 +235,21 @@ def _selective_scan_update_kernel(
...
@@ -193,10 +235,21 @@ def _selective_scan_update_kernel(
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
state
=
state
*
dA
+
dB
*
x
[:,
None
]
state
=
state
*
dA
+
dB
*
x
[:,
None
]
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
IS_SPEC_DECODING
:
if
HAS_STATE_BATCH_INDICES
:
dst_idx_ptr
=
dst_state_batch_indices_ptr
+
i_t
*
stride_dst_state_indices_T
mask
&=
state_batch_idx
!=
pad_slot_id
token_dst_idx
=
tl
.
load
(
dst_idx_ptr
).
to
(
tl
.
int64
)
tl
.
store
(
dst_state_ptrs
,
state
,
mask
=
mask
)
if
token_dst_idx
!=
pad_slot_id
:
token_dst_ptrs
=
(
state_ptr_base
+
token_dst_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
+
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
)
tl
.
store
(
token_dst_ptrs
,
state
.
to
(
token_dst_ptrs
.
dtype
.
element_ty
),
mask
=
mask
)
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
if
HAS_D
:
if
HAS_D
:
out
+=
x
*
D
out
+=
x
*
D
...
@@ -204,6 +257,17 @@ def _selective_scan_update_kernel(
...
@@ -204,6 +257,17 @@ def _selective_scan_update_kernel(
out
*=
z
*
tl
.
sigmoid
(
z
)
out
*=
z
*
tl
.
sigmoid
(
z
)
tl
.
store
(
out_ptrs
,
out
,
mask
=
offs_m
<
dim
)
tl
.
store
(
out_ptrs
,
out
,
mask
=
offs_m
<
dim
)
x_ptr
+=
stride_x_batch
dt_ptr
+=
stride_dt_batch
B_ptr
+=
stride_B_batch
C_ptr
+=
stride_C_batch
out_ptr
+=
stride_out_batch
if
HAS_Z
:
z_ptr
+=
stride_z_batch
if
not
IS_SPEC_DECODING
:
tl
.
store
(
dst_state_ptrs
,
state
.
to
(
dst_state_ptrs
.
dtype
.
element_ty
),
mask
=
mask
)
def
selective_state_update
(
def
selective_state_update
(
state
,
state
,
...
@@ -220,6 +284,8 @@ def selective_state_update(
...
@@ -220,6 +284,8 @@ def selective_state_update(
dst_state_batch_indices
=
None
,
dst_state_batch_indices
=
None
,
pad_slot_id
=
PAD_SLOT_ID
,
pad_slot_id
=
PAD_SLOT_ID
,
out
=
None
,
out
=
None
,
num_accepted_tokens
=
None
,
cu_seqlens
=
None
,
):
):
"""
"""
Argument:
Argument:
...
@@ -240,6 +306,11 @@ def selective_state_update(
...
@@ -240,6 +306,11 @@ def selective_state_update(
indices 0 and 3
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
In-place updated.
num_accepted_tokens: (batch,)
number of accepted tokens from previous verification step,
tells the kernel which initial state to use
cu_seqlens: (batch,)
length per sequence, for variable length in speculative decoding cases
"""
"""
if
state
.
dim
()
==
3
:
if
state
.
dim
()
==
3
:
state
=
state
.
unsqueeze
(
1
)
state
=
state
.
unsqueeze
(
1
)
...
@@ -261,9 +332,26 @@ def selective_state_update(
...
@@ -261,9 +332,26 @@ def selective_state_update(
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
if
out
.
dim
()
==
2
:
if
out
.
dim
()
==
2
:
out
=
out
.
unsqueeze
(
1
)
out
=
out
.
unsqueeze
(
1
)
if
num_accepted_tokens
is
not
None
:
assert
state_batch_indices
is
not
None
and
state_batch_indices
.
dim
()
==
2
assert
dst_state_batch_indices
is
None
or
dst_state_batch_indices
.
dim
()
==
2
if
state_batch_indices
is
not
None
and
state_batch_indices
.
dim
()
==
1
:
state_batch_indices
=
state_batch_indices
.
unsqueeze
(
1
)
if
dst_state_batch_indices
is
not
None
and
dst_state_batch_indices
.
dim
()
==
1
:
dst_state_batch_indices
=
dst_state_batch_indices
.
unsqueeze
(
1
)
_
,
nheads
,
dim
,
dstate
=
state
.
shape
_
,
nheads
,
dim
,
dstate
=
state
.
shape
batch
=
x
.
shape
[
0
]
batch
=
x
.
shape
[
0
]
if
cu_seqlens
is
not
None
:
N
=
len
(
cu_seqlens
)
-
1
# Only used to verify the shape of
# state_batch_indices and dst_state_batch_indices
max_seqlen
=
(
state_batch_indices
.
size
(
-
1
)
if
state_batch_indices
is
not
None
else
1
)
else
:
N
=
batch
max_seqlen
=
1
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
dt
.
shape
==
x
.
shape
assert
dt
.
shape
==
x
.
shape
...
@@ -279,16 +367,30 @@ def selective_state_update(
...
@@ -279,16 +367,30 @@ def selective_state_update(
if
dt_bias
is
not
None
:
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
if
state_batch_indices
is
not
None
:
if
state_batch_indices
is
not
None
:
assert
state_batch_indices
.
shape
==
(
batch
,)
assert
state_batch_indices
.
shape
[
0
]
>=
N
assert
state_batch_indices
.
shape
[
1
]
>=
max_seqlen
if
dst_state_batch_indices
is
not
None
:
if
dst_state_batch_indices
is
not
None
:
assert
dst_state_batch_indices
.
shape
==
(
batch
,)
assert
dst_state_batch_indices
.
shape
[
0
]
>=
N
assert
dst_state_batch_indices
.
shape
[
1
]
>=
max_seqlen
else
:
else
:
# revert to the default behavior of in-place state updates
# revert to the default behavior of in-place state updates
dst_state_batch_indices
=
state_batch_indices
dst_state_batch_indices
=
state_batch_indices
assert
out
.
shape
==
x
.
shape
assert
out
.
shape
==
x
.
shape
if
num_accepted_tokens
is
not
None
:
assert
num_accepted_tokens
.
shape
==
(
N
,)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
"BLOCK_SIZE_M"
]),
batch
,
nheads
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
"BLOCK_SIZE_M"
]),
N
,
nheads
)
z_strides
=
(
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
(
0
,
0
,
0
)
z_strides
=
(
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
(
0
,
0
,
0
)
state_batch_indices_strides
=
(
(
state_batch_indices
.
stride
(
0
),
state_batch_indices
.
stride
(
1
))
if
state_batch_indices
is
not
None
else
(
0
,
0
)
)
dst_state_batch_indices_strides
=
(
(
dst_state_batch_indices
.
stride
(
0
),
dst_state_batch_indices
.
stride
(
1
))
if
dst_state_batch_indices
is
not
None
else
(
0
,
0
)
)
# We don't want autotune since it will overwrite the state
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
# We instead tune by hand.
BLOCK_SIZE_M
,
num_warps
=
(
BLOCK_SIZE_M
,
num_warps
=
(
...
@@ -321,7 +423,9 @@ def selective_state_update(
...
@@ -321,7 +423,9 @@ def selective_state_update(
state_batch_indices
,
state_batch_indices
,
dst_state_batch_indices
,
dst_state_batch_indices
,
pad_slot_id
,
pad_slot_id
,
batch
,
num_accepted_tokens
,
cu_seqlens
,
N
,
nheads
,
nheads
,
dim
,
dim
,
dstate
,
dstate
,
...
@@ -353,6 +457,10 @@ def selective_state_update(
...
@@ -353,6 +457,10 @@ def selective_state_update(
out
.
stride
(
0
),
out
.
stride
(
0
),
out
.
stride
(
1
),
out
.
stride
(
1
),
out
.
stride
(
2
),
out
.
stride
(
2
),
state_batch_indices_strides
[
0
],
state_batch_indices_strides
[
1
],
dst_state_batch_indices_strides
[
0
],
dst_state_batch_indices_strides
[
1
],
dt_softplus
,
dt_softplus
,
tie_hdim
,
tie_hdim
,
BLOCK_SIZE_M
,
BLOCK_SIZE_M
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment