Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
300892fe
Commit
300892fe
authored
Dec 23, 2024
by
zhuwenwen
Browse files
update test_cutlass.py
parent
c56b26cd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
3 deletions
+4
-3
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+4
-3
No files found.
tests/kernels/test_cutlass.py
View file @
300892fe
...
@@ -128,9 +128,10 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -128,9 +128,10 @@ def cutlass_int8_gemm_helper(m: int,
elif
torch_version
.
startswith
(
"2.4"
):
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
# opcheck(torch.ops._C.cutlass_scaled_mm,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
# (out, a, b, scale_a, scale_b, bias))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment