Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a2d659d
Unverified
Commit
6a2d659d
authored
Jun 28, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jun 28, 2024
Browse files
[Bugfix] Fix compute datatype for cutlass 3.x epilogues (#5931)
parent
b2c62023
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
59 deletions
+70
-59
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+2
-2
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+68
-57
No files found.
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
6a2d659d
...
@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
...
@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
ElementD
,
ElementD
,
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
ElementD
,
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
BiasDescriptor
=
using
BiasDescriptor
=
...
...
tests/kernels/test_cutlass.py
View file @
6a2d659d
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
Run `pytest tests/kernels/test_cutlass.py`.
"""
"""
from
typing
import
Type
from
typing
import
Optional
,
Type
import
pytest
import
pytest
import
torch
import
torch
...
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
...
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
def
cutlass_fp8_gemm_helper
(
m
:
int
,
def
cutlass_fp8_gemm_helper
(
m
:
int
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
bias
:
bool
,
use_
bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# Test for a cutlass kernel with per-token activation quantization
...
@@ -43,23 +58,19 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -43,23 +58,19 @@ def cutlass_fp8_gemm_helper(m: int,
m_a_scales
=
m
if
per_token_act_quant
else
1
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
if
bias
:
if
use_bias
:
# bias term should be > 1 so that the absolute tolerance can catch it
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
bias_t
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
+
1.0
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias_t
)
else
:
else
:
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
bias
=
None
bias_t
=
0
baseline
=
(
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
))
+
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
bias_t
).
to
(
out_dtype
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1
e-
1
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5
e-
2
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
@@ -67,7 +78,7 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -67,7 +78,7 @@ def cutlass_int8_gemm_helper(m: int,
k
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
bias
:
bool
,
use_
bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# Test for a cutlass kernel with per-token activation quantization
...
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
m_a_scales
=
m
if
per_token_act_quant
else
1
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
dtype
=
torch
.
float32
))
if
bias
:
if
use_bias
:
# bias term should be > 1 so that the absolute tolerance can catch it
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
bias_t
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
+
1.0
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias_t
)
else
:
else
:
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
bias
=
None
bias_t
=
0
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
(
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
))
+
bias_t
).
to
(
dtype
=
out_dtype
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
...
@@ -102,12 +110,12 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -102,12 +110,12 @@ def cutlass_int8_gemm_helper(m: int,
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
bias
:
bool
):
per_out_ch
:
bool
,
use_
bias
:
bool
):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
bias
)
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_
bias
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
...
@@ -115,70 +123,70 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
...
@@ -115,70 +123,70 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
bias
:
bool
):
per_out_ch
:
bool
,
use_
bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
bias
)
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_
bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
bool
):
use_
bias
:
bool
):
cutlass_int8_gemm_helper
(
512
,
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
512
,
512
,
per_act_token
,
per_act_token
,
per_out_ch
,
per_out_ch
,
bias
,
use_
bias
,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
bool
):
use_
bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
512
,
512
,
per_act_token
,
per_act_token
,
per_out_ch
,
per_out_ch
,
bias
,
use_
bias
,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
bias
:
bool
,
device
:
str
):
use_
bias
:
bool
,
device
:
str
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
bias
,
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_
bias
,
torch
.
bfloat16
,
device
)
torch
.
bfloat16
,
device
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
bias
:
bool
,
device
:
str
):
use_
bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
512
,
512
,
per_act_token
,
per_act_token
,
per_out_ch
,
per_out_ch
,
bias
,
use_
bias
,
out_dtype
=
torch
.
bfloat16
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
device
=
device
)
...
@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
...
@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# kernel must handle any M thrown at it.
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
bias
:
bool
):
use_
bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
for
m
in
range
(
1
,
128
):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
bias
)
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"
use_
bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
bias
:
bool
):
use_
bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
bias
)
use_
bias
)
# Test working with a subset of A and B
# Test working with a subset of A and B
...
@@ -229,9 +238,11 @@ def test_cutlass_subset():
...
@@ -229,9 +238,11 @@ def test_cutlass_subset():
scale_a
,
scale_a
,
scale_b
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
out_dtype
=
torch
.
bfloat16
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
baseline
=
baseline_scaled_mm
(
a
,
scale_b
*
b
,
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
torch
.
bfloat16
)
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment