Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32ec9e2f
Unverified
Commit
32ec9e2f
authored
Jul 23, 2025
by
Yu Chin Fabian Lim
Committed by
GitHub
Jul 23, 2025
Browse files
Mamba V2 Test not Asserting Failures. (#21379)
Signed-off-by:
Yu Chin Fabian Lim
<
flim@sg.ibm.com
>
parent
accac829
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
10 deletions
+25
-10
tests/kernels/mamba/test_mamba_mixer2.py
tests/kernels/mamba/test_mamba_mixer2.py
+5
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+20
-6
No files found.
tests/kernels/mamba/test_mamba_mixer2.py
View file @
32ec9e2f
...
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
...
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
gate_states
[...,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
gate_states
[...,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
)
)
ref_output
=
mixer_single_gpu
(
hidden_states
,
gate_states
)
ref_output
=
mixer_single_gpu
(
hidden_states
,
gate_states
)
torch
.
allclose
(
output
,
torch
.
testing
.
assert_close
(
output
,
ref_output
[...,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
ref_output
[...,
atol
=
1e-3
,
local_rank
*
N
:(
local_rank
+
1
)
*
N
],
rtol
=
1e-3
)
atol
=
5e-3
,
rtol
=
1e-3
)
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
32ec9e2f
...
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
...
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# this tests the kernels on a single example (no batching)
# this tests the kernels on a single example (no batching)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
if
itype
==
torch
.
bfloat16
:
atol
,
rtol
=
5e-2
,
5e-2
else
:
atol
,
rtol
=
8e-3
,
5e-3
# set seed
# set seed
batch_size
=
1
# batch_size
batch_size
=
1
# batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen
# ssd_minimal_discrete requires chunk_size divide seqlen
...
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
...
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
return_final_states
=
True
)
return_final_states
=
True
)
# just test the last in sequence
# just test the last in sequence
torch
.
all
close
(
Y
[:,
-
1
],
Y_min
[:,
-
1
],
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_
close
(
Y
[:,
-
1
],
Y_min
[:,
-
1
],
atol
=
atol
,
rtol
=
rtol
)
# just test the last head
# just test the last head
# NOTE, in the kernel we always cast states to fp32
# NOTE, in the kernel we always cast states to fp32
torch
.
all
close
(
final_state
[:,
-
1
],
torch
.
testing
.
assert_
close
(
final_state
[:,
-
1
],
final_state_min
[:,
-
1
].
to
(
torch
.
float32
),
final_state_min
[:,
-
1
].
to
(
torch
.
float32
),
atol
=
1e-3
,
atol
=
atol
,
rtol
=
1e-3
)
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
])
...
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seqlen
,
chunk_size
,
num_examples
,
cases
=
seq_len_chunk_size_cases
seqlen
,
chunk_size
,
num_examples
,
cases
=
seq_len_chunk_size_cases
# TODO: the irregular chunk size cases have some issues and require higher
# tolerance. This is to be investigated
if
chunk_size
not
in
{
8
,
256
}:
atol
,
rtol
=
5e-1
,
5e-1
else
:
atol
,
rtol
=
5e-3
,
5e-3
# hold state during the cutting process so we know if an
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
# example has been exhausted and needs to cycle
last_taken
:
dict
=
{}
# map: eg -> pointer to last taken sample
last_taken
:
dict
=
{}
# map: eg -> pointer to last taken sample
...
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test one dim and dstate
# just test one dim and dstate
Y_eg
=
Y
[
0
,
cu_seqlens
[
i
]:
cu_seqlens
[
i
+
1
],
0
,
0
]
Y_eg
=
Y
[
0
,
cu_seqlens
[
i
]:
cu_seqlens
[
i
+
1
],
0
,
0
]
Y_min_eg
=
Y_min
[
i
][:,
0
,
0
]
Y_min_eg
=
Y_min
[
i
][:,
0
,
0
]
torch
.
all
close
(
Y_eg
,
Y_min_eg
,
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_
close
(
Y_eg
,
Y_min_eg
,
atol
=
atol
,
rtol
=
rtol
)
# update states
# update states
states
=
new_states
states
=
new_states
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment