sglang · Commits · 640363ad
Unverified commit 640363ad, authored Feb 13, 2025 by yizhang2077, committed by GitHub on Feb 13, 2025

support blockwise fp8 matmul kernel (#3267)

Parent: 8616357a
Showing 11 changed files with 1366 additions and 0 deletions (+1366, -0):
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py  +148 -0
sgl-kernel/setup.py  +1 -0
sgl-kernel/src/sgl-kernel/__init__.py  +2 -0
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp  +125 -0
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp  +733 -0
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp  +33 -0
sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu  +191 -0
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h  +5 -0
sgl-kernel/src/sgl-kernel/ops/__init__.py  +10 -0
sgl-kernel/src/sgl-kernel/torch_extension.cc  +6 -0
sgl-kernel/tests/test_fp8_blockwise_gemm.py  +112 -0
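Taken together, these files expose a single new Python-level op, sgl_kernel.fp8_blockwise_scaled_mm. A minimal sketch of the expected call pattern, using the (1, 128) activation and (128, 128) weight scale groups from the test and benchmark below; the shapes and the 448 clamp constant (the float8_e4m3fn maximum) are illustrative, not part of the diff:

import torch
from sgl_kernel import fp8_blockwise_scaled_mm

M, N, K = 128, 4096, 7168  # illustrative; K and N divisible by the 128-wide scale blocks

# FP8 inputs: A is row major (M, K); B is column major, built as (N, K) and then transposed.
a_fp8 = torch.randn(M, K, device="cuda").clamp(-448, 448).to(torch.float8_e4m3fn)
b_fp8 = torch.randn(N, K, device="cuda").clamp(-448, 448).to(torch.float8_e4m3fn).t()

# Blockwise scales: (1, 128) groups for A -> (M, K // 128); (128, 128) groups for B -> (K // 128, N // 128).
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(K // 128, N // 128, device="cuda", dtype=torch.float32)
# The kernel expects M-major scales_a and K-major scales_b, hence the stride trick used in the tests below.
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()

out = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.bfloat16)  # (M, N), bfloat16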
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py (new file, mode 100644)
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def get_weight_shapes(args):
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): The weight shapes only work for DeepSeek-V3. Modify them if you tune for a different model.
    # cannot TP
    total = [
        # (512 + 64, 7168), # this weight is not supported by current kernel
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    # only support DeepSeek-V3
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        for t in total:
            new_t = [t[0], t[1], model]
            weight_shapes.append(new_t)
        for n_t in n_tp:
            new_t = [n_t[0] // tp_size, n_t[1], model]
            weight_shapes.append(new_t)
        for k_t in k_tp:
            new_t = [k_t[0], k_t[1] // tp_size, model]
            weight_shapes.append(new_t)
    return weight_shapes


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
    for N, K, model_name in NK_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
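For reference, the GB/s number plotted by this script is produced by the gbps lambda above. The same expression, spelled out as a standalone helper for readability (the function name and example values here are illustrative only, not part of the commit):

def gbps_from_ms(ms: float, M: int, N: int, K: int, fp8_bytes: int = 1, scale_bytes: int = 4) -> float:
    # Same arithmetic as the gbps lambda in benchmark(), converted from per-millisecond to per-second.
    work = (2 * M * N * K - M * N) * fp8_bytes + (3 * M * N) * scale_bytes
    return work * 1e-9 / (ms * 1e-3)

# e.g. gbps_from_ms(1.0, M=128, N=7168, K=16384) for one of the DeepSeek-V3 shapes above.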
sgl-kernel/setup.py

@@ -96,6 +96,7 @@ sources = [
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "src/sgl-kernel/csrc/eagle_utils.cu",
sgl-kernel/src/sgl-kernel/__init__.py

@@ -14,6 +14,7 @@ from sgl_kernel.ops import (
     build_tree_kernel_efficient,
     custom_dispose,
     custom_reduce,
+    fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     fused_add_rmsnorm,
     gelu_and_mul,
@@ -44,6 +45,7 @@ __all__ = [
     "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
+    "fp8_blockwise_scaled_mm",
     "fp8_scaled_mm",
     "fused_add_rmsnorm",
     "gelu_and_mul",
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp (new file, mode 100644)

// Adapt from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_buildler.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
    class ElementA,
    class GmemLayoutATag,
    int AlignmentA,
    class ElementB,
    class GmemLayoutBTag,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    int ScaleGranularityM>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm =
      (cute::is_any_of_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative, KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert(
      (!IsFP8Input || !IsArrayOfPointersGemm),
      "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<
      KernelScheduleType,
      KernelTmaWarpSpecializedCooperative,
      KernelPtrArrayTmaWarpSpecializedCooperative,
      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative, Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
      AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage =
      IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<
      detail::sm90_smem_capacity_bytes - KernelSmemCarveout, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<
      PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (new file, mode 100644)

This diff is collapsed.
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp (new file, mode 100644)

// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
#pragma once

#include <cutlass/gemm/dispatch_policy.hpp>

namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                               // while zero-value `ScaleGranularityM` indicates that scaling
                               // granularity is `size<0>(TileShape_MNK{})` along M.
    >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
  static_assert(
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};

//////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm
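The comments above define the ScaleGranularityM convention: a positive value is the scaling granularity (rows of A per scale group) along M, and zero means one scale per CTA tile along M, i.e. granularity equal to size<0>(TileShape_MNK{}). A small sketch of that convention in Python (the helper names are hypothetical, for illustration only; the SM90 kernel below instantiates TileShape M = 128 with ScaleGranularityM = 1):

def effective_scale_granularity_m(scale_granularity_m: int, tile_m: int) -> int:
    # Zero selects "one scale per tile along M", i.e. granularity equal to the tile's M extent.
    return tile_m if scale_granularity_m == 0 else scale_granularity_m

def scale_groups_per_tile_m(scale_granularity_m: int, tile_m: int) -> int:
    g = effective_scale_granularity_m(scale_granularity_m, tile_m)
    assert tile_m % g == 0
    return tile_m // g

print(scale_groups_per_tile_m(1, 128))  # 128 scale groups per tile, one per row of A
print(scale_groups_per_tile_m(0, 128))  # 1 scale group per tile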
sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu (new file, mode 100644)
#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "utils.h"

using namespace cute;

template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}

template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}

torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched");
  TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
  TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched");
  TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
    } else {
      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
    }
    return out;
  }
#endif
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
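The TORCH_CHECK guards above spell out the contract the Python caller must satisfy. A hedged restatement of the same constraints as a small Python checker (the validate helper is illustrative only, not part of the commit):

import torch

def validate_blockwise_inputs(mat_a, mat_b, scales_a, scales_b):
    # Mirrors the TORCH_CHECK guards in fp8_blockwise_scaled_mm above.
    assert mat_a.is_cuda and mat_b.is_cuda
    assert mat_a.dim() == 2 and mat_b.dim() == 2
    assert mat_a.stride(1) == 1                       # A row major
    assert mat_b.stride(0) == 1                       # B column major
    assert mat_a.size(1) == mat_b.size(0)             # inner dimensions match
    assert mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn
    # Scale shapes follow the 128-wide groups assumed by the kernel.
    assert scales_a.shape == (mat_a.size(0), mat_a.size(1) // 128)
    assert scales_b.shape == (mat_b.size(0) // 128, mat_b.size(1) // 128)
    assert scales_a.dtype == torch.float32 and scales_b.dtype == torch.float32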
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h

@@ -60,6 +60,11 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);
 
+// fp8_blockwise_scaled_mm
+torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+                                      const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+                                      const torch::Dtype& out_dtype);
+
 // lightning_attention_decode
 void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
sgl-kernel/src/sgl-kernel/ops/__init__.py

@@ -125,6 +125,16 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     )
 
 
+def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
+    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
+        mat_a,
+        mat_b,
+        scales_a,
+        scales_b,
+        out_dtype,
+    )
+
+
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     return torch.ops.sgl_kernels.fp8_scaled_mm(
         mat_a,
sgl-kernel/src/sgl-kernel/torch_extension.cc

@@ -54,6 +54,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "bias) -> Tensor");
   m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
 
+  // fp8_blockwise_scaled_mm
+  m.def(
+      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
+      "Tensor");
+  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
+
   // lightning_attention_decode
   m.def(
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
sgl-kernel/tests/test_fp8_blockwise_gemm.py (new file, mode 100644)
import unittest
from typing import Optional, Type

import torch
from sgl_kernel import fp8_blockwise_scaled_mm


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: Type[torch.dtype],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # We treat N-dimensional group scaling as extended numpy-style broadcasting:
    # numpy simply stretches dimensions with an extent of 1 to match the target
    # shape by repeating the data along that dimension (broadcasting). We extend
    # these semantics to say: if the extent of a dimension in the source shape is
    # not 1 and does not match the target shape, we repeat each element along
    # that dimension target_shape[dim] // src_shape[dim] times.
    # Example: if we have
    #   a = [[1, 2],   and target_shape = (2, 4)
    #        [3, 4]]
    # then we would expand a to:
    #   a = [[1, 1, 2, 2],
    #        [3, 3, 4, 4]]
    # NOTE: this function does not explicitly broadcast dimensions with an
    # extent of 1, since this can be done implicitly by pytorch.
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    output = torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    ).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


class TestFp8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, out_dtype, device):
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

        scale_a_group_shape = (1, 128)
        scale_b_group_shape = (128, 128)
        scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
        scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

        scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
        scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
        scale_a = scale_a.t().contiguous().t()
        scale_b = scale_b.t().contiguous().t()

        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
        o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
        o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
        rtol = 0.02
        atol = 1
        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        out_dtypes = [torch.bfloat16, torch.float16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for out_dtype in out_dtypes:
                        self._test_accuracy_once(M, N, K, out_dtype, "cuda")


if __name__ == "__main__":
    unittest.main()
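The group_broadcast helper in baseline_scaled_mm implements the extended broadcasting described in its comment. Run standalone on the comment's own example, it behaves as follows (a copy of the helper, shown only to illustrate the semantics):

import torch

def group_broadcast(t, shape):
    # Repeat each element along any dimension whose extent is neither 1 nor the target extent.
    for i, s in enumerate(shape):
        if t.shape[i] != s and t.shape[i] != 1:
            assert s % t.shape[i] == 0
            t = (
                t.unsqueeze(i + 1)
                .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                .flatten(i, i + 1)
            )
    return t

a = torch.tensor([[1, 2], [3, 4]])
print(group_broadcast(a, (2, 4)))
# tensor([[1, 1, 2, 2],
#         [3, 3, 4, 4]])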