Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e23564cb
Unverified
Commit
e23564cb
authored
May 16, 2025
by
Lain
Committed by
GitHub
May 16, 2025
Browse files
use ceil_div in cutlass block scaling shape check (#17918)
parent
390ec889
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
25 deletions
+62
-25
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+10
-2
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+39
-21
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+13
-2
No files found.
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
e23564cb
...
...
@@ -115,8 +115,16 @@ def bench_fp8(
a_cont
=
a
.
contiguous
()
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_a
=
torch
.
rand
((
m
,
k
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_b
=
torch
.
rand
((
k
//
128
,
n
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
def ceil_div(x: int, y: int) -> int:
    """Return the ceiling of ``x / y`` using integer arithmetic only.

    Used to compute how many blocks of size ``y`` are needed to cover a
    dimension of size ``x`` when ``x`` is not a multiple of ``y``.

    Args:
        x: Dividend (e.g. a tensor dimension).
        y: Divisor (e.g. the block size). Must be non-zero.

    Returns:
        The smallest integer ``q`` such that ``q * y >= x`` for ``y > 0``
        (and the mathematical ceiling of ``x / y`` for ``y < 0``).

    Raises:
        ZeroDivisionError: If ``y`` is zero.
    """
    # Negating before and after a floor division turns "round toward -inf"
    # into "round toward +inf". Unlike ``(x + y - 1) // y`` this is exact for
    # either sign of ``y`` while giving identical results for ``y > 0``.
    return -(-x // y)
block_scale_a
=
torch
.
rand
(
(
m
,
ceil_div
(
k
,
128
)),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_b
=
torch
.
rand
(
ceil_div
(
k
,
128
),
ceil_div
(
n
,
128
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
block_scale_a_M_major
=
block_scale_a
.
t
().
contiguous
().
t
()
block_scale_b_K_major
=
block_scale_b
.
t
().
contiguous
().
t
()
bias
=
torch
.
zeros
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
View file @
e23564cb
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template
<
typename
Fp8Func
,
typename
Int8Func
,
typename
BlockwiseFunc
>
void
dispatch_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -28,6 +29,21 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
}
}
}
else
{
TORCH_CHECK
(
a_scales
.
dim
()
==
2
,
"a scale must be 2d tensor."
);
TORCH_CHECK
(
b_scales
.
dim
()
==
2
,
"b scale must be 2d tensor."
);
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
100
)
{
TORCH_CHECK
(
a
.
size
(
0
)
==
a_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
a
.
size
(
1
),
int64_t
(
128
))
==
a_scales
.
size
(
1
),
"a_scale_group_shape must be [1, 128]."
);
TORCH_CHECK
(
cuda_utils
::
ceil_div
(
b
.
size
(
0
),
int64_t
(
128
))
==
b_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
b
.
size
(
1
),
int64_t
(
128
))
==
b_scales
.
size
(
1
),
"b_scale_group_shape must be [128, 128]."
);
}
else
{
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
...
...
@@ -51,6 +67,8 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
}
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
blockwise_func
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
e23564cb
...
...
@@ -115,6 +115,17 @@ def apply_w8a8_block_fp8_linear(
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
if
current_platform
.
is_cuda
():
if
current_platform
.
has_device_capability
(
100
):
def ceil_div(x: int, y: int) -> int:
    """Return the ceiling of ``x / y`` using integer arithmetic only.

    Gives the number of ``y``-sized blocks needed to cover ``x`` elements,
    counting a trailing partial block as a full one.

    Args:
        x: Dividend (e.g. a weight dimension).
        y: Divisor (e.g. the quantization block size). Must be non-zero.

    Returns:
        The mathematical ceiling of ``x / y`` as an ``int``.

    Raises:
        ZeroDivisionError: If ``y`` is zero.
    """
    # Floor-dividing the negated dividend and negating the result rounds
    # toward +inf. This matches ``(x + y - 1) // y`` for every ``y > 0`` and,
    # unlike that form, is also correct when ``y`` is negative.
    return -(-x // y)
use_cutlass
=
cutlass_block_fp8_supported
and
(
ceil_div
(
weight
.
shape
[
0
],
128
)
==
weight_scale
.
shape
[
0
]
and
ceil_div
(
weight
.
shape
[
1
],
128
)
==
weight_scale
.
shape
[
1
])
else
:
# TODO: update this after switching to public sm90 block scale gemm
# as it also supports weight.shape % 128 != 0
use_cutlass
=
cutlass_block_fp8_supported
and
(
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment