Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dbb036cf
Unverified
Commit
dbb036cf
authored
Apr 15, 2025
by
Tyler Michael Smith
Committed by
GitHub
Apr 15, 2025
Browse files
[Bugfix] Fix tests/kernels/test_mamba_ssm_ssd.py (#16623)
Signed-off-by:
Tyler Michael Smith
<
tyler@neuralmagic.com
>
parent
70e7ed84
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
5 deletions
+12
-5
tests/kernels/test_mamba_ssm_ssd.py
tests/kernels/test_mamba_ssm_ssd.py
+12
-5
No files found.
tests/kernels/test_mamba_ssm_ssd.py
View file @
dbb036cf
...
...
@@ -5,6 +5,8 @@ import torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
_seq_idx_to_chunk_indices_offsets
)
from
vllm.model_executor.layers.mamba.ops.ssd_combined
import
(
mamba_chunk_scan_combined
)
from
vllm.platforms
import
current_platform
...
...
@@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch,
# get the metadata
cu_seqlens
=
torch
.
tensor
((
0
,
)
+
spec
,
device
=
device
).
cumsum
(
dim
=
0
)
se
d
_idx
=
torch
.
zeros
(
cu_seqlens
[
-
1
],
se
q
_idx
=
torch
.
zeros
(
cu_seqlens
[
-
1
],
dtype
=
torch
.
int32
,
device
=
cu_seqlens
.
device
)
for
i
,
(
srt
,
end
)
in
enumerate
(
zip
(
cu_seqlens
,
cu_seqlens
[
1
:],
)):
se
d
_idx
[
srt
:
end
]
=
i
se
q
_idx
[
srt
:
end
]
=
i
# for cont batch
if
IND_E
is
None
:
...
...
@@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
IND_E
=
[
end_boundary
(
x
+
y
)
for
x
,
y
in
zip
(
IND_S
,
spec
)]
yield
([
Y_min
[
s
,
IND_S
[
s
]:
IND_E
[
s
]]
for
s
in
range
(
num_examples
)],
cu_seqlens
,
se
d
_idx
.
unsqueeze
(
0
),
(
A
,
dt2
,
X2
,
B2
,
C2
))
cu_seqlens
,
se
q
_idx
.
unsqueeze
(
0
),
(
A
,
dt2
,
X2
,
B2
,
C2
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
...
...
@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
for
Y_min
,
cu_seqlens
,
se
d
_idx
,
(
A
,
dt
,
X
,
B
,
for
Y_min
,
cu_seqlens
,
se
q
_idx
,
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continous_batched_examples
(
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
_seq_idx_to_chunk_indices_offsets
(
seq_idx
,
chunk_size
)
Y
,
new_states
=
mamba_chunk_scan_combined
(
X
,
dt
,
...
...
@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
chunk_size
,
D
=
None
,
cu_seqlens
=
cu_seqlens
,
seq_idx
=
sed_idx
,
seq_idx
=
seq_idx
,
chunk_indices
=
chunk_indices
,
chunk_offsets
=
chunk_offsets
,
return_varlen_states
=
True
,
initial_states
=
states
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment