Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f2cd251
Unverified
Commit
5f2cd251
authored
Jun 04, 2025
by
Lain
Committed by
GitHub
Jun 04, 2025
Browse files
Sm100 blockwise fp8 swap ab (#18564)
parent
02658c2d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
140 additions
and
84 deletions
+140
-84
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
...ization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+0
-4
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+140
-66
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+0
-14
No files found.
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
View file @
5f2cd251
...
...
@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
size
(
0
)
%
4
==
0
,
"Input tensor must have a number of rows that is a multiple of 4. "
,
"but got: "
,
a
.
size
(
0
),
" rows."
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
View file @
5f2cd251
#pragma once
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
...
...
@@ -22,49 +23,49 @@ namespace vllm {
using
namespace
cute
;
template
<
typename
OutType
,
typename
MmaTileShape
,
typename
ScalesPerTile
,
class
ClusterShape
,
typename
EpilogueScheduler
,
typename
MainloopScheduler
>
// clang-format off
template
<
class
OutType
,
int
ScaleGranularityM
,
int
ScaleGranularityN
,
int
ScaleGranularityK
,
class
MmaTileShape
,
class
ClusterShape
,
class
EpilogueScheduler
,
class
MainloopScheduler
,
bool
swap_ab_
=
false
>
struct
cutlass_3x_gemm_fp8_blockwise
{
static
constexpr
bool
swap_ab
=
swap_ab_
;
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutA_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutA
>::
type
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutB_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutB
>::
type
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementC
=
void
;
using
ElementD
=
OutType
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutD
>::
type
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
// TODO: support bias
using
LayoutC
=
LayoutD
;
using
LayoutC_Transpose
=
LayoutD_Transpose
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static
constexpr
int
ScaleMsPerTile
=
size
<
0
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityM
=
size
<
0
>
(
MmaTileShape
{})
/
ScaleMsPerTile
;
static
constexpr
int
ScaleGranularityN
=
size
<
1
>
(
MmaTileShape
{})
/
size
<
1
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityK
=
size
<
2
>
(
MmaTileShape
{})
/
size
<
2
>
(
ScalesPerTile
{});
// Shape of the threadblocks in a cluster
using
ClusterShape_MNK
=
ClusterShape
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>
;
using
ScaleConfig
=
conditional_t
<
swap_ab
,
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
K
,
cute
::
UMMA
::
Major
::
MN
>
,
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>>
;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
...
...
@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
ElementScalar
=
float
;
// clang-format off
using
DefaultOperation
=
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
ElementCompute
,
ElementC
,
ElementScalar
,
RoundStyle
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
...
...
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
conditional_t
<
swap_ab
,
LayoutC_Transpose
,
LayoutC
>
,
AlignmentC
,
ElementD
,
LayoutD
,
conditional_t
<
swap_ab
,
LayoutD_Transpose
,
LayoutD
>
,
AlignmentD
,
EpilogueScheduler
,
DefaultOperation
>::
CollectiveOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
Ele
mentB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
Cluster
Shape
,
using
CollectiveMainloop
=
conditional_t
<
swap_ab
,
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementB
,
cute
::
tuple
<
LayoutB_Transpose
,
LayoutSFA
>
,
Align
mentB
,
ElementA
,
cute
::
tuple
<
LayoutA_Transpose
,
LayoutSFB
>
,
AlignmentA
,
ElementAccumulator
,
MmaTile
Shape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
;
// clang-format on
MainloopScheduler
>::
CollectiveOp
,
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
>
;
using
KernelType
=
enable_sm100_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
...
...
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
static
constexpr
bool
swap_ab
=
Gemm
::
swap_ab
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
...
...
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
StrideA
a_stride
;
StrideB
b_stride
;
...
...
@@ -146,11 +160,13 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
cutlass
::
make_cute_packed_stride
(
StrideC
{},
swap_ab
?
cute
::
make_shape
(
n
,
m
,
1
)
:
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
LayoutSFA
layout_SFA
=
swap_ab
?
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
LayoutSFB
layout_SFB
=
swap_ab
?
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
...
...
@@ -158,9 +174,22 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
auto
mainloop_args
=
[
&
](){
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
if
(
swap_ab
)
{
return
typename
GemmKernel
::
MainloopArguments
{
b_ptr
,
b_stride
,
a_ptr
,
a_stride
,
b_scales_ptr
,
layout_SFA
,
a_scales_ptr
,
layout_SFB
};
}
else
{
return
typename
GemmKernel
::
MainloopArguments
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
}
}();
auto
prob_shape
=
swap_ab
?
cute
::
make_shape
(
n
,
m
,
k
,
1
)
:
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
...
...
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
m
=
a
.
size
(
0
);
auto
k
=
a
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
int
sms
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
),
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
a
.
get_device
());
auto
should_use_2sm
=
[
&
sms
](
int
m
,
int
n
,
int
tile1SM
=
128
)
{
return
std
::
ceil
(
static_cast
<
float
>
(
m
)
/
tile1SM
)
*
std
::
ceil
(
static_cast
<
float
>
(
n
)
/
tile1SM
)
>=
sms
;
};
bool
use_2sm
=
should_use_2sm
(
m
,
n
);
if
(
use_2sm
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
Shape
<
_256
,
_128
,
_128
>
,
Shape
<
_256
,
_1
,
_1
>
,
Shape
<
_2
,
_2
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
constexpr
int
TILE_K
=
128
;
// TODO: better heuristics
bool
swap_ab
=
(
m
<
16
)
||
(
m
%
4
!=
0
);
bool
use_tma_epilogue
=
(
m
*
n
)
%
4
==
0
;
if
(
!
swap_ab
)
{
constexpr
int
TILE_N
=
128
;
int
tile_m
=
256
;
if
(
cuda_utils
::
ceil_div
(
n
,
TILE_N
)
*
cuda_utils
::
ceil_div
(
m
,
64
)
<=
sms
)
{
tile_m
=
64
;
}
else
if
(
cuda_utils
::
ceil_div
(
n
,
TILE_N
)
*
cuda_utils
::
ceil_div
(
m
,
128
)
<=
sms
)
{
tile_m
=
128
;
}
if
(
tile_m
==
64
)
{
if
(
use_tma_epilogue
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_64
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_64
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
if
(
tile_m
==
128
)
{
if
(
use_tma_epilogue
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_128
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_128
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
{
// tile_m == 256
if
(
use_tma_epilogue
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_256
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_2
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_256
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_2
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
else
{
// TODO: Test more tile N configs
constexpr
int
TILE_M
=
128
;
constexpr
int
TILE_N
=
16
;
// TMA epilogue isn't compatible with Swap A/B
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
Shape
<
_128
,
_128
,
_128
>
,
Shape
<
_128
,
_1
,
_1
>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
Tma
WarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
OutType
,
TILE_M
,
1
,
TILE_K
,
Shape
<
Int
<
TILE_M
>
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>
>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmem
WarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
,
true
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
5f2cd251
...
...
@@ -136,24 +136,10 @@ def apply_w8a8_block_fp8_linear(
use_cutlass
,
use_aiter_and_is_supported
)
if
use_cutlass
:
rows
,
cols
=
input_2d
.
shape
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
# optimal tensor core usage. Can be removed when targeting platforms
# without this constraint.
should_pad
=
current_platform
.
has_device_capability
(
100
)
and
rows
%
4
!=
0
if
should_pad
:
input_2d
=
torch
.
nn
.
functional
.
pad
(
input_2d
,
(
0
,
0
,
0
,
4
-
(
rows
%
4
)),
value
=
0
).
contiguous
()
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
column_major_scales
=
use_cutlass
)
output
=
w8a8_blockscale_func
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
input
.
dtype
)
if
should_pad
:
output
=
output
[:
rows
,
:]
else
:
q_input
,
x_scale
=
per_token_group_quant_fp8
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment