Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
26513bb5
Commit
26513bb5
authored
Dec 03, 2024
by
gaoqiong
Browse files
修改cutlass 单测
parent
1a9775b8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
242 additions
and
230 deletions
+242
-230
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+240
-229
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-1
No files found.
tests/kernels/test_cutlass.py
View file @
26513bb5
...
...
@@ -9,14 +9,14 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
#
from vllm.platforms import current_platform
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
0
}
"
#
for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
#
capability = current_platform.get_device_capability()
capability
=
90
#
capability[0] * 10 + capability[1]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
...
...
@@ -39,7 +39,7 @@ def baseline_scaled_mm(a: torch.Tensor,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
output
=
(
scale_a
*
(
scale_b
.
T
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
...
...
@@ -99,7 +99,7 @@ def cutlass_int8_gemm_helper(m: int,
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
scale_b
=
(
torch
.
randn
((
n_b_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
...
...
@@ -107,42 +107,48 @@ def cutlass_int8_gemm_helper(m: int,
else
:
bias
=
None
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
# print("a.shape:",a.shape)
# print("b.shape:",b.shape)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
# print("out:",out[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
#
opcheck(torch.ops._C.cutlass_scaled_mm,
#
(out, a, b, scale_a, scale_b, bias))
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
4096
,
8192
,
16384
,
24576
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
#
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
#
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
#
@pytest.mark.parametrize("k", [128, 496, 1024])
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
#
per_out_ch: bool, use_bias: bool):
#
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
8192
,
16384
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
b
float16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
float16
])
#
torch.
b
float16
,
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
...
...
@@ -156,50 +162,50 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
,
device
:
str
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
torch
.
bfloat16
,
device
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
#
out_dtype: Type[torch.dtype],
#
use_bias: bool):
#
cutlass_fp8_gemm_helper(512,
#
512,
#
512,
#
per_act_token,
#
per_out_ch,
#
use_bias,
#
out_dtype=out_dtype)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.parametrize("device", CUDA_DEVICES)
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool, device: str):
#
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
#
torch.bfloat16, device)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.parametrize("device", CUDA_DEVICES)
#
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool, device: str):
#
cutlass_int8_gemm_helper(512,
#
512,
#
512,
#
per_act_token,
#
per_out_ch,
#
use_bias,
#
out_dtype=torch.bfloat16,
#
device=device)
# For the following two tests:
...
...
@@ -207,155 +213,155 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
skip
def
test_cutlass_int8_azp_bias_fold
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
):
# Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
aq_i8
=
rand_int8
((
m
,
k
))
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
b_dq
=
scale_b
*
bq_f32
azp_a
=
torch
.
rand
((
1
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
+
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
+
azp_a
)
baseline_dq
=
torch
.
mm
(
a_dq
,
b_dq
).
to
(
out_dtype
)
J
=
torch
.
ones
((
1
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
azp_bias
=
(
azp_a
*
scale_b
*
(
J
@
bq_f32
)).
to
(
out_dtype
)
assert
azp_bias
.
shape
==
(
1
,
n
)
assert
azp_bias
[
0
,
:].
shape
==
(
n
,
)
baseline_q
=
(
scale_a
.
to
(
device
=
'cpu'
)
*
scale_b
.
to
(
device
=
'cpu'
)
*
(
(
aq_i32
+
azp_aq_i8
).
to
(
device
=
'cpu'
)
@
bq_i32
.
to
(
device
=
'cpu'
))).
to
(
dtype
=
out_dtype
,
device
=
'cuda'
)
out
=
ops
.
cutlass_scaled_mm
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
azp_bias
[
0
,
:])
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
1e-2
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
1e-2
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"azp_per_token"
,
[
True
,
False
])
def
test_cutlass_int8_azp
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
,
use_bias
:
bool
,
azp_per_token
:
bool
):
m_azp
=
m
if
azp_per_token
else
1
scale_a
=
torch
.
randn
((
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
aq_i8
=
rand_int8
((
m
,
k
))
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
b_dq
=
scale_b
*
bq_f32
azp_a
=
torch
.
rand
(
(
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
-
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
-
azp_a
,
rtol
=
1e-4
,
atol
=
1e-3
)
if
use_bias
:
bias
=
torch
.
rand
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
+
2.5
else
:
bias
=
torch
.
zeros
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
baseline_dq
=
(
torch
.
mm
(
a_dq
,
b_dq
)
+
bias
).
to
(
out_dtype
)
# int32 mm not supported on CUDA
a_noazp_i32_cpu
=
(
aq_i32
-
azp_aq_i8
).
to
(
device
=
'cpu'
)
cq
=
(
a_noazp_i32_cpu
@
bq_i32
.
to
(
device
=
'cpu'
)).
to
(
device
=
'cuda'
)
baseline_q
=
(
scale_a
*
scale_b
*
cq
+
bias
).
to
(
dtype
=
out_dtype
)
# Hadamard is just the sum of the cols
azp_adj_i32
=
bq_i32
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
azp_i32
=
azp_aq_i8
.
to
(
dtype
=
torch
.
int32
)
func_bias
=
bias
if
use_bias
else
None
if
azp_per_token
:
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
,
azp_adj_i32
,
azp_i32
,
func_bias
)
else
:
azp_with_adj_i32
=
azp_i32
*
azp_adj_i32
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
,
azp_with_adj_i32
,
None
,
func_bias
)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol
=
1e-2
if
out_dtype
==
torch
.
bfloat16
else
1e-3
atol
=
1e-3
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
if
azp_per_token
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_adj_i32
,
azp_i32
,
func_bias
))
else
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_with_adj_i32
,
None
,
func_bias
))
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool):
#
for nk in range(32, 128, 32):
#
for m in range(1, 128):
#
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
#
use_bias)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool):
#
for nk in range(32, 128, 32):
#
for m in range(1, 128):
#
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
#
use_bias)
#
@pytest.mark.parametrize("m", [32, 64, 128])
#
@pytest.mark.parametrize("n", [16, 32, 64])
#
@pytest.mark.parametrize("k", [64, 128, 256])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
#
@pytest.mark.skip
#
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
#
out_dtype: torch.dtype):
#
# Currently, the test is failing because folding azp into
#
# 16-bit bias loses too much precision
#
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
#
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
#
aq_i8 = rand_int8((m, k))
#
bq_i8 = rand_int8((n, k)).t()
#
aq_i32 = aq_i8.to(dtype=torch.int32)
#
bq_i32 = bq_i8.to(dtype=torch.int32)
#
aq_f32 = aq_i8.to(dtype=torch.float32)
#
bq_f32 = bq_i8.to(dtype=torch.float32)
#
b_dq = scale_b * bq_f32
#
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
#
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
#
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
#
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
#
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
#
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
#
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
#
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
#
assert azp_bias.shape == (1, n)
#
assert azp_bias[0, :].shape == (n, )
#
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
#
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
#
dtype=out_dtype, device='cuda')
#
out = ops.cutlass_scaled_mm(aq_i8,
#
bq_i8,
#
scale_a,
#
scale_b,
#
out_dtype=out_dtype,
#
bias=azp_bias[0, :])
#
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
#
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
#
@pytest.mark.parametrize("m", [32, 64, 128])
#
@pytest.mark.parametrize("n", [16, 32, 64])
#
@pytest.mark.parametrize("k", [64, 128, 256])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.parametrize("azp_per_token", [True, False])
#
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
#
use_bias: bool, azp_per_token: bool):
#
m_azp = m if azp_per_token else 1
#
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
#
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
#
aq_i8 = rand_int8((m, k))
#
aq_i32 = aq_i8.to(dtype=torch.int32)
#
aq_f32 = aq_i8.to(dtype=torch.float32)
#
bq_i8 = rand_int8((n, k)).t()
#
bq_i32 = bq_i8.to(dtype=torch.int32)
#
bq_f32 = bq_i8.to(dtype=torch.float32)
#
b_dq = scale_b * bq_f32
#
azp_a = torch.rand(
#
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
#
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
#
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
#
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
#
torch.testing.assert_close(a_dq,
#
scale_a * aq_f32 - azp_a,
#
rtol=1e-4,
#
atol=1e-3)
#
if use_bias:
#
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
#
else:
#
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
#
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
#
# int32 mm not supported on CUDA
#
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
#
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
#
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
#
# Hadamard is just the sum of the cols
#
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
#
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
#
func_bias = bias if use_bias else None
#
if azp_per_token:
#
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
#
out_dtype, azp_adj_i32, azp_i32,
#
func_bias)
#
else:
#
azp_with_adj_i32 = azp_i32 * azp_adj_i32
#
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
#
out_dtype, azp_with_adj_i32, None,
#
func_bias)
#
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
#
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
#
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
#
atol = 1e-3
#
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
#
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
#
if azp_per_token:
#
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
#
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
#
func_bias))
#
else:
#
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
#
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
#
func_bias))
# Test working with a subset of A and B
...
...
@@ -367,7 +373,11 @@ def test_cutlass_subset():
whole_b
=
to_int8
(
torch
.
randn
((
big_n
,
big_k
),
device
=
"cuda"
).
t
()
*
5
)
a
=
whole_a
[
0
:
m
,
0
:
k
]
b
=
whole_b
[
0
:
k
,
0
:
n
]
#变成连续内存,矩阵子模块目前不支持计算,需要重新计算lda
a
=
a
.
contiguous
().
reshape
(
m
,
-
1
)
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
...
...
@@ -399,25 +409,26 @@ class CutlassLayer(torch.nn.Module):
return
ops
.
cutlass_scaled_mm
(
a
,
self
.
b
,
self
.
scale_a
,
self
.
scale_b
,
self
.
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
#目前只支持per-act-token+per-out-ch(fp16)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
def
test_cutlass_cuda_graph
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
m
,
n
,
k
=
512
,
512
,
512
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
"cuda"
))
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
())
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
m_a_scales
=
m
if
per_act_token
else
1
n_b_scales
=
n
if
per_out_ch
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
(
n_b_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
b
float16
)
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
float16
)
# Run the model with a cuda graph
stream
=
torch
.
cuda
.
Stream
()
...
...
@@ -429,9 +440,9 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
g
.
replay
()
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
bfloat16
)
scale_b
.
T
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
float16
)
#print("baseline:",baseline)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
)
#print("out:",out)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
def
test_cutlass_support_opcheck
():
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
,
(
capability
,
))
vllm/_custom_ops.py
View file @
26513bb5
...
...
@@ -706,7 +706,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out
return
quant_ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment