Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
d5c88afa
"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "6891d3ecbe93b19c939834b102f844e4da9a967f"
Unverified
Commit
d5c88afa
authored
Oct 04, 2025
by
Lei Wang
Committed by
GitHub
Oct 04, 2025
Browse files
[Example] Add correctness assert into dsa example (#937)
parent
242cb457
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
8 deletions
+52
-8
examples/deepseek_v32/sparse_mla_fwd.py
examples/deepseek_v32/sparse_mla_fwd.py
+8
-1
examples/deepseek_v32/utils.py
examples/deepseek_v32/utils.py
+44
-7
No files found.
examples/deepseek_v32/sparse_mla_fwd.py
View file @
d5c88afa
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
torch
import
torch
import
tilelang
import
tilelang
from
tilelang
import
language
as
T
from
tilelang
import
language
as
T
from
utils
import
assert_tensors_similar
@
tilelang
.
jit
(
@
tilelang
.
jit
(
...
@@ -253,6 +254,12 @@ def test_sparse_mla_fwd(B=1,
...
@@ -253,6 +254,12 @@ def test_sparse_mla_fwd(B=1,
tl_out
,
tl_lse
=
sparse_mla_fwd_interface
(
q
,
kv
,
indices
)
tl_out
,
tl_lse
=
sparse_mla_fwd_interface
(
q
,
kv
,
indices
)
if
SKV
<=
4096
:
# otherwise may cause out of memory
ref_out
=
ref_sparse_mla_fwd_interface
(
q
,
kv
,
indices
)
assert_tensors_similar
(
tl_out
,
ref_out
,
eps
=
1e-2
,
name
=
"out"
)
print
(
"assert_tensors_similar passed"
)
def
fn
():
def
fn
():
return
sparse_mla_fwd_interface
(
q
,
kv
,
indices
)
return
sparse_mla_fwd_interface
(
q
,
kv
,
indices
)
...
@@ -270,4 +277,4 @@ def test_sparse_mla_fwd(B=1,
...
@@ -270,4 +277,4 @@ def test_sparse_mla_fwd(B=1,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_sparse_mla_fwd
(
test_sparse_mla_fwd
(
B
=
1
,
S
=
4096
,
SKV
=
32768
,
H
=
128
,
HKV
=
1
,
DQK
=
576
,
DV
=
512
,
topk
=
2048
,
dtype
=
torch
.
bfloat16
)
B
=
1
,
S
=
4096
,
SKV
=
4096
,
H
=
128
,
HKV
=
1
,
DQK
=
576
,
DV
=
512
,
topk
=
2048
,
dtype
=
torch
.
bfloat16
)
examples/deepseek_v32/utils.py
View file @
d5c88afa
...
@@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
...
@@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
return
ks
,
ke
return
ks
,
ke
def
print_red_warning
(
message
):
def
calculate_tensor_similarity
(
x
,
y
,
name
=
"tensor"
):
print
(
f
"
\033
[31mWARNING:
{
message
}
\033
[0m"
)
"""
Calculate similarity between two tensors using a normalized dot product metric.
Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
element-wise differences, this function computes a global similarity score:
sim = 2 * <x, y> / (||x||^2 + ||y||^2)
This metric is scale-invariant and measures the cosine-like similarity normalized
by the magnitude of both tensors. It returns 1 for identical tensors and values
closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
with varying magnitudes where relative errors matter more than absolute differences.
Args:
x: First tensor to compare
y: Second tensor to compare
name: Name of the tensor for logging purposes
def
calc_sim
(
x
,
y
,
name
=
"tensor"
):
Returns:
Similarity score in range [0, 1] where 1 means identical
"""
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
if
denominator
==
0
:
if
denominator
==
0
:
print
_red_warning
(
f
'
{
name
}
all zero
'
)
print
(
f
"
\033
[33mWARNING:
{
name
}
all zero
\033
[0m"
)
return
1
return
1
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
sim
return
sim
def
assert_similar
(
x
,
y
,
eps
=
1e-8
,
name
=
"tensor"
,
raise_assert
=
True
):
def
assert_tensors_similar
(
x
,
y
,
eps
=
1e-8
,
name
=
"tensor"
,
raise_assert
=
True
):
sim
=
calc_sim
(
x
,
y
,
name
)
"""
Assert that two tensors are similar using a global similarity metric.
Key differences from torch.testing.assert_close:
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
and requires all elements to satisfy the tolerance.
- assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
normalized dot product. It's more robust to outliers and focuses on overall
tensor similarity rather than element-wise precision. This is better suited for
comparing large tensors where a few outlier elements shouldn't fail the test.
Args:
x: First tensor to compare
y: Second tensor to compare
eps: Maximum allowed difference (1 - similarity), default 1e-8
name: Name of the tensor for error messages
raise_assert: Whether to raise assertion error on failure
"""
sim
=
calculate_tensor_similarity
(
x
,
y
,
name
)
diff
=
1.
-
sim
diff
=
1.
-
sim
if
not
(
0
<=
diff
<=
eps
):
if
not
(
0
<=
diff
<=
eps
):
print_red_warning
(
f
'
{
name
}
Error:
{
diff
}
'
)
print
(
f
"
\033
[31mERROR:
{
name
}
similarity check failed, diff=
{
diff
:.
2
e
}
(threshold=
{
eps
:.
2
e
}
)
\033
[0m"
)
if
raise_assert
:
if
raise_assert
:
assert
False
# noqa: B011
assert
False
# noqa: B011
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment