change / sglang / Commits / ebf495f0

Unverified commit ebf495f0, authored Apr 10, 2025 by Yi Zhang, committed by GitHub Apr 09, 2025
Parent: 7f875f12

sgl-kernel use cutlass latest version for fp8 blockwise gemm (#5207)

Showing 6 changed files with 86 additions and 923 deletions (+86, -923)
Changed files:

  sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py (+49, -16)
  sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp (+0, -125)
  sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (+0, -733)
  sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp (+0, -37)
  sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu (+34, -9)
  sgl-kernel/tests/test_fp8_blockwise_gemm.py (+3, -3)
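For context, the call pattern that the updated benchmark exercises for the "sgl-kernel" provider looks roughly like the following. This is a minimal sketch assembled from the benchmark code below, not a separate API reference: the (1, 128) / (128, 128) scale group shapes and the transpose/contiguity handling mirror bench_fp8_blockwise_gemm.py, while the concrete M/N/K values are illustrative only.

import torch
from sgl_kernel import fp8_blockwise_scaled_mm

M, N, K = 1024, 4096, 7168  # illustrative shapes; N and K are multiples of 128
fp8 = torch.finfo(torch.float8_e4m3fn)

# Random FP8 operands: A is (M, K); B is stored as (N, K) and transposed for the kernel.
a_fp8 = (torch.rand(M, K, device="cuda") * 2 - 1).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)
b_fp8 = (torch.rand(N, K, device="cuda") * 2 - 1).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)

# Blockwise scales: (1, 128) groups along A's rows, (128, 128) blocks for B.
scale_a = torch.randn(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.randn(N // 128, K // 128, device="cuda", dtype=torch.float32)

# Layout handling as done in the benchmark for the sgl-kernel provider.
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()

out = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16)  # (M, N) fp16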
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py (view file @ ebf495f0)

@@ -2,18 +2,22 @@ import argparse
 import copy
 import itertools
 
+import deep_gemm
 import torch
 import triton
+from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 
+from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+
 
 def get_weight_shapes(args):
     models_tps = list(itertools.product(args.models, args.tp_sizes))
     # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
     # cannot TP
     total = [
-        # (512 + 64, 7168),  # this weight is not supported by current kernel
+        (512 + 64, 7168),
         ((128 + 64) * 128, 7168),
         (128 * (128 + 128), 512),
         (7168, 16384),
...
@@ -52,6 +56,23 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def fp8_gemm_deepgemm(
+    x_fp8: torch.Tensor,
+    x_scale: torch.Tensor,
+    y_fp8: torch.Tensor,
+    y_scale: torch.Tensor,
+    m: int,
+    n: int,
+    k: int,
+):
+    """DeepGEMM implementation of FP8 GEMM"""
+    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
+
+    # Run DeepGEMM kernel
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    return out
+
+
 def scale_shape(shape, group_shape):
     assert len(shape) == len(group_shape)
     return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
...
@@ -60,12 +81,12 @@ def scale_shape(shape, group_shape):
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel"],
-        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
+        line_names=["vllm ", "sgl-kernel", "sglang triton", "deep gemm"],
+        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
         ylabel="GB/s",
         plot_name="fp8 blockwise scaled matmul",
         args={},
...
@@ -80,7 +101,7 @@ def benchmark(batch_size, provider, N, K):
     a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
     b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
-    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
     scale_a_group_shape = (1, 128)
     scale_b_group_shape = (128, 128)
...
@@ -89,11 +110,11 @@ def benchmark(batch_size, provider, N, K):
     scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
     scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
-    scale_a = scale_a.t().contiguous().t()
-    scale_b = scale_b.t().contiguous().t()
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl-kernel":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: fp8_blockwise_scaled_mm(
                 a_fp8, b_fp8, scale_a, scale_b, torch.float16
...
@@ -101,19 +122,28 @@ def benchmark(batch_size, provider, N, K):
             quantiles=quantiles,
         )
     if provider == "vllm":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
-    gbps = (
-        lambda ms: (
-            (2 * M * N * K - M * N) * a_fp8.element_size()
-            + (3 * M * N) * scale_a.element_size()
-        )
-        * 1e-9
-        / (ms * 1e-3)
-    )
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    if provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: w8a8_block_fp8_matmul(
+                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
+            ),
+            quantiles=quantiles,
+        )
+    if provider == "deepgemm":
+        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: fp8_gemm_deepgemm(
+                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
+            ),
+            quantiles=quantiles,
+        )
+    return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms
 
 
 if __name__ == "__main__":
...
@@ -136,6 +166,9 @@ if __name__ == "__main__":
     NK_model_names = get_weight_shapes(args)
     for N, K, model_name in NK_model_names:
+        if N % 128 != 0 or K % 128 != 0:
+            print(f"Skip {N=}, {K=} now")
+            continue
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,
...
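As a quick sanity check on the scale shapes used above, cdiv and scale_shape (copied verbatim from the benchmark) give the following for an illustrative DeepSeek-V3-like problem size; the concrete M/N/K values here are only examples.

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))

# Illustrative M=1024, K=7168, N=4096:
print(scale_shape((1024, 7168), (1, 128)))    # (1024, 56) -> per-row, per-128-column scales for A
print(scale_shape((4096, 7168), (128, 128)))  # (32, 56)   -> 128x128 block scales for B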
sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp (deleted, 100644 → 0; view file @ 7f875f12)

Entire 125-line file removed; its former contents were:

// Adapt from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_buildler.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
    class ElementA,
    class GmemLayoutATag,
    int AlignmentA,
    class ElementB,
    class GmemLayoutBTag,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    int ScaleGranularityM>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<
      KernelScheduleType,
      KernelPtrArrayTmaWarpSpecializedCooperative,
      KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert(
      (!IsFP8Input || !IsArrayOfPointersGemm),
      "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<
      KernelScheduleType,
      KernelTmaWarpSpecializedCooperative,
      KernelPtrArrayTmaWarpSpecializedCooperative,
      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative, Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
      AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage =
      IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<
      detail::sm90_smem_capacity_bytes - KernelSmemCarveout, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<
      PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (deleted, 100644 → 0; view file @ 7f875f12)

Entire 733-line file removed; the diff is collapsed in the original view and not reproduced here.
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp (deleted, 100644 → 0; view file @ 7f875f12)

Entire 37-line file removed; its former contents were:

// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
#pragma once

#include <cutlass/gemm/dispatch_policy.hpp>

namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                               // while zero-value `ScaleGranularityM` indicates that scaling
                               // granularity is `size<0>(TileShape_MNK{})` along M.
    >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
  static_assert(
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};

//////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu (view file @ ebf495f0)

@@ -30,13 +30,16 @@
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 #include <cutlass/util/packed_stride.hpp>
 
-#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
-#include "cutlass_extensions/gemm/dispatch_policy.hpp"
 #include "utils.h"
 
 using namespace cute;
 
-template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
+template <
+    typename SchedulerType,
+    typename OutType,
+    typename TileShape,
+    typename ClusterShape,
+    typename ScaleGranularity>
 void launch_sm90_fp8_blockwise_scaled_mm(
     torch::Tensor& out,
     const torch::Tensor& a,
...
@@ -63,6 +66,9 @@ void launch_sm90_fp8_blockwise_scaled_mm(
   using LayoutD = cutlass::layout::RowMajor;
   constexpr int AlignmentD = AlignmentC;
 
+  static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
+  static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
+
   using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
   using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
...
@@ -70,7 +76,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
   using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
   using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
       OperatorClass,
...
@@ -108,7 +114,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
       Shape<int, int, int, int>,  // Indicates ProblemShape
       CollectiveMainloop,
       CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>;
+      SchedulerType>;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
   Gemm gemm_op;
...
@@ -299,8 +305,26 @@ void sm90_fp8_blockwise_dispatch_shape(
     const torch::Tensor& scales_a,
     const torch::Tensor& scales_b) {
   using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_1, _1, _1>;
+  using ClusterShape = Shape<_1, _2, _1>;
+  using ScaleGranularity = Shape<_1, _128, _128>;
 
-  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
+  auto k = a.size(1);
+  auto n = b.size(1);
+  if (k > 3 * n) {
+    launch_sm90_fp8_blockwise_scaled_mm<
+        cutlass::gemm::StreamKScheduler, OutType, TileShape, ClusterShape, ScaleGranularity>(
+        out, a, b, scales_a, scales_b);
+  } else {
+    launch_sm90_fp8_blockwise_scaled_mm<
+        cutlass::gemm::PersistentScheduler, OutType, TileShape, ClusterShape, ScaleGranularity>(
+        out, a, b, scales_a, scales_b);
+  }
 }
 
 template <typename OutType>
...
@@ -372,10 +396,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
   if (sm_version == 90) {
+    torch::Tensor scales_b_contiguous = scales_b.contiguous();
     if (out_dtype == torch::kBFloat16) {
-      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
+      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
     } else {
-      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
+      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
     }
     return out;
   }
...
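The added scales_b.contiguous() call matters because callers (including the benchmark above) typically pass B's scales transposed, which yields a non-contiguous view. A minimal Python-side illustration, with illustrative shapes only; the extension's host code now performs the equivalent of the final .contiguous() before dispatching:

import torch

# B is (N, K) = (4096, 7168) with 128x128 scale blocks, so scale_b starts as (32, 56).
scale_b = torch.randn(32, 56, dtype=torch.float32)
scale_b_t = scale_b.t()                        # (56, 32) view, not contiguous
print(scale_b_t.is_contiguous())               # False
print(scale_b_t.contiguous().is_contiguous())  # True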
sgl-kernel/tests/test_fp8_blockwise_gemm.py (view file @ ebf495f0)

@@ -82,9 +82,9 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
     print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
 
 
-@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
-@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
-@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 14080])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 14080, 16384])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
 def test_accuracy(M, N, K, out_dtype):
     _test_accuracy_once(M, N, K, out_dtype, "cuda")
...
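The widened parametrization adds M values that are not multiples of 128 (1, 3, 5, 127) and larger N/K sizes. Assuming a standard pytest setup and a CUDA-capable GPU, the module can be run on its own, for example:

import pytest

# Run only this test file; -q keeps the large parametrized matrix compact.
pytest.main(["-q", "sgl-kernel/tests/test_fp8_blockwise_gemm.py"])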