Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9798b2fb
Unverified
Commit
9798b2fb
authored
Jan 30, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jan 30, 2025
Browse files
[Kernel] Update `cutlass_scaled_mm` to support 2d group (blockwise) scaling (#11868)
parent
4078052f
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
188 additions
and
61 deletions
+188
-61
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+0
-3
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+4
-0
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+132
-56
tests/kernels/utils.py
tests/kernels/utils.py
+30
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+22
-0
No files found.
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
9798b2fb
...
@@ -89,15 +89,12 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
...
@@ -89,15 +89,12 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
b
.
size
(
1
)
==
c
.
size
(
1
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
...
...
csrc/quantization/machete/machete_mainloop.cuh
View file @
9798b2fb
...
@@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
...
@@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
using
PipelineState
=
cutlass
::
PipelineState
<
DispatchPolicy
::
Stages
>
;
using
PipelineState
=
cutlass
::
PipelineState
<
DispatchPolicy
::
Stages
>
;
using
PipelineParams
=
typename
MainloopPipeline
::
Params
;
using
PipelineParams
=
typename
MainloopPipeline
::
Params
;
// One threads per CTA are producers (1 for operand tile)
static
constexpr
int
NumProducerThreadEvents
=
1
;
using
ScaleTileShape
=
decltype
(
make_shape
(
shape
<
0
>
(
TileShape
{}),
using
ScaleTileShape
=
decltype
(
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
1
>
(
SmemLayoutAtomScale
{})));
shape
<
1
>
(
SmemLayoutAtomScale
{})));
...
...
tests/kernels/test_cutlass.py
View file @
9798b2fb
...
@@ -10,6 +10,7 @@ import torch
...
@@ -10,6 +10,7 @@ import torch
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
.utils
import
baseline_scaled_mm
,
to_fp8
,
to_int8
from
.utils
import
baseline_scaled_mm
,
to_fp8
,
to_int8
...
@@ -39,6 +40,11 @@ CUDA_DEVICES = [
...
@@ -39,6 +40,11 @@ CUDA_DEVICES = [
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE
=
(
-
1
,
-
1
)
PER_TOKEN_GROUP_SHAPE
=
(
1
,
-
1
)
PER_OUT_CH_GROUP_SHAPE
=
(
-
1
,
1
)
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
...
@@ -47,11 +53,22 @@ def rand_int8(shape: tuple, device: str = "cuda"):
...
@@ -47,11 +53,22 @@ def rand_int8(shape: tuple, device: str = "cuda"):
return
to_int8
(
torch
.
rand
(
shape
,
device
=
device
)
*
255
-
128
)
return
to_int8
(
torch
.
rand
(
shape
,
device
=
device
)
*
255
-
128
)
def
group_scale_helper
(
shape
,
group_shape
):
return
[
shape
[
i
]
if
s
<
0
else
s
for
i
,
s
in
enumerate
(
group_shape
)]
def
scale_shape
(
shape
,
group_shape
):
assert
len
(
shape
)
==
len
(
group_shape
)
group_shape
=
group_scale_helper
(
shape
,
group_shape
)
return
tuple
(
cdiv
(
shape
[
i
],
group_shape
[
i
])
for
i
in
range
(
len
(
group_shape
)))
def
cutlass_fp8_gemm_helper
(
m
:
int
,
def
cutlass_fp8_gemm_helper
(
m
:
int
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
a_scale_group_shape
:
tuple
,
per_out_channel_weight_quant
:
bool
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
...
@@ -60,13 +77,17 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -60,13 +77,17 @@ def cutlass_fp8_gemm_helper(m: int,
a
=
to_fp8
(
torch
.
randn
((
m
,
k
),
device
=
device
))
a
=
to_fp8
(
torch
.
randn
((
m
,
k
),
device
=
device
))
b
=
to_fp8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
())
b
=
to_fp8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
())
m_a_scales
=
m
if
per_token_act_quant
else
1
a_scales_shape
=
scale_shape
(
a
.
shape
,
a_scale_group_shape
)
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
b_scales_shape
=
scale_shape
(
b
.
shape
,
b_scale_group_shape
)
scale_a
=
(
torch
.
randn
(
a_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
b_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a
=
scale_a
.
t
().
contiguous
().
t
()
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b
=
scale_b
.
t
().
contiguous
().
t
()
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
else
:
else
:
...
@@ -84,8 +105,8 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -84,8 +105,8 @@ def cutlass_fp8_gemm_helper(m: int,
def
cutlass_int8_gemm_helper
(
m
:
int
,
def
cutlass_int8_gemm_helper
(
m
:
int
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
a_scale_group_shape
:
tuple
,
per_out_channel_weight_quant
:
bool
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
...
@@ -94,13 +115,11 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -94,13 +115,11 @@ def cutlass_int8_gemm_helper(m: int,
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
device
)
*
5
)
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
device
)
*
5
)
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
()
*
5
)
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
()
*
5
)
m_
a_scales
=
m
if
per_token_act_quant
else
1
a_scales
_shape
=
scale_shape
(
a
.
shape
,
a_scale_group_shape
)
n_
b_scales
=
n
if
per_out_channel_weight_quant
else
1
b_scales
_shape
=
scale_shape
(
b
.
shape
,
b_scale_group_shape
)
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
scale_a
=
(
torch
.
randn
(
a_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
b_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
...
@@ -117,82 +136,135 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -117,82 +136,135 @@ def cutlass_int8_gemm_helper(m: int,
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
a_scale_group_shape
,
per_out_ch
:
bool
,
use_bias
:
bool
):
b_scale_group_shape
,
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape,b_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[((
1
,
128
),
(
128
,
128
))])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"FP8 blockwise is not supported on this GPU type."
)
def
test_cutlass_fp8_blockwise_scale_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
):
if
k
%
b_scale_group_shape
[
0
]
!=
0
or
n
%
b_scale_group_shape
[
1
]
!=
0
:
return
if
m
%
a_scale_group_shape
[
0
]
!=
0
or
k
%
a_scale_group_shape
[
1
]
!=
0
:
return
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
a_scale_group_shape
,
per_out_ch
:
bool
,
use_bias
:
bool
):
b_scale_group_shape
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
use_bias
:
bool
):
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
512
,
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
512
,
512
,
per_act_token
,
a_scale_group_shape
,
per_out_ch
,
b_scale_group_shape
,
use_bias
,
use_bias
,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
use_bias
:
bool
):
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
512
,
512
,
per_act_token
,
a_scale_group_shape
,
per_out_ch
,
b_scale_group_shape
,
use_bias
,
use_bias
,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape,b_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[((
1
,
128
),
(
128
,
128
))])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"FP8 blockwise is not supported on this GPU type."
)
def
test_cutlass_fp8_blockwise_scale_gemm_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
Type
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_devices
(
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
,
device
:
str
):
use_bias
:
bool
,
device
:
str
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
a_scale_group_shape
,
torch
.
bfloat16
,
device
)
b_scale_group_shape
,
use_bias
,
torch
.
bfloat16
,
device
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_devices
(
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
,
device
:
str
):
use_bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
512
,
512
,
per_act_token
,
a_scale_group_shape
,
per_out_ch
,
b_scale_group_shape
,
use_bias
,
use_bias
,
out_dtype
=
torch
.
bfloat16
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
device
=
device
)
...
@@ -203,28 +275,32 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
...
@@ -203,28 +275,32 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_fp8_gemm_m_sweep
(
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
):
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
for
m
in
range
(
1
,
128
):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
a_scale_group_shape
,
use_bias
)
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_m_sweep
(
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
):
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
a_scale_group_shape
,
use_bias
)
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
...
...
tests/kernels/utils.py
View file @
9798b2fb
...
@@ -1119,8 +1119,36 @@ def baseline_scaled_mm(a: torch.Tensor,
...
@@ -1119,8 +1119,36 @@ def baseline_scaled_mm(a: torch.Tensor,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match the
# the target shape by repeating the data along that dimension (broadcasting)
# , we extend these semantics to say if the extent of a dimension in the
# source shape is not 1 and does not match the target shape we repeat each
# element along that dimension src_shape[dim] // target_shape[dim] times
# example if we have:
# a = [[1, 2], and target_shape = (2, 4)
# [3, 4]]
# then we would expand a to:
# a = [[1, 1, 2, 2],
# [3, 3, 4, 4]]
# NOTE this function this function does not explicitly broadcast dimensions
# with an extent of 1, since this can be done implicitly by pytorch
def
group_broadcast
(
t
,
shape
):
for
i
,
s
in
enumerate
(
shape
):
if
t
.
shape
[
i
]
!=
s
and
t
.
shape
[
i
]
!=
1
:
assert
s
%
t
.
shape
[
i
]
==
0
t
=
t
.
unsqueeze
(
i
+
1
)
\
.
expand
(
*
t
.
shape
[:
i
+
1
],
s
//
t
.
shape
[
i
],
*
t
.
shape
[
i
+
1
:])
\
.
flatten
(
i
,
i
+
1
)
return
t
scale_a
=
group_broadcast
(
scale_a
,
a
.
shape
)
scale_b
=
group_broadcast
(
scale_b
,
b
.
shape
)
output
=
torch
.
mm
((
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
)),
(
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
))).
to
(
out_dtype
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
=
output
+
bias
output
=
output
+
bias
...
...
vllm/_custom_ops.py
View file @
9798b2fb
...
@@ -441,6 +441,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -441,6 +441,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension src_shape[dim] // target_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
b
.
shape
[
assert
bias
is
None
or
bias
.
shape
[
0
]
==
b
.
shape
[
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment