sglang · Commit 640363ad (unverified)

support blockwise fp8 matmul kernel (#3267)

Authored Feb 13, 2025 by yizhang2077; committed by GitHub on Feb 13, 2025
Parent: 8616357a
Showing 11 changed files with 1366 additions and 0 deletions (+1366 -0):

sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py  (+148 -0)
sgl-kernel/setup.py  (+1 -0)
sgl-kernel/src/sgl-kernel/__init__.py  (+2 -0)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp  (+125 -0)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp  (+733 -0)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp  (+33 -0)
sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu  (+191 -0)
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h  (+5 -0)
sgl-kernel/src/sgl-kernel/ops/__init__.py  (+10 -0)
sgl-kernel/src/sgl-kernel/torch_extension.cc  (+6 -0)
sgl-kernel/tests/test_fp8_blockwise_gemm.py  (+112 -0)
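For orientation, the new op is exposed to Python as sgl_kernel.fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype). Below is a minimal usage sketch, assuming a Hopper (SM90) CUDA device and following the layout conventions used by the test and benchmark added in this commit; the problem sizes are illustrative only.

import torch
from sgl_kernel import fp8_blockwise_scaled_mm

M, N, K = 512, 4096, 7168  # illustrative sizes; N and K are multiples of 128
fp8 = torch.finfo(torch.float8_e4m3fn)

# Row-major fp8 activations (M, K) and column-major fp8 weights (K, N).
a = (torch.rand(M, K, device="cuda") - 0.5).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)
b = (torch.rand(N, K, device="cuda") - 0.5).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn).t()

# Float32 scales: (1, 128) groups along K for A, (128, 128) blocks for B.
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(K // 128, N // 128, device="cuda", dtype=torch.float32)
scale_a = scale_a.t().contiguous().t()  # the kernel expects scales_a to be M major
scale_b = scale_b.t().contiguous().t()  # and scales_b to be K major

out = fp8_blockwise_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)  # (M, N) bf16 result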
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py  (new file, 0 → 100644)
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def get_weight_shapes(args):
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): These weight shapes only work for DeepSeek-V3. Modify them if you tune for a different model.
    # cannot TP
    total = [
        # (512 + 64, 7168), # this weight is not supported by current kernel
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only DeepSeek-V3 is supported
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        for t in total:
            new_t = [t[0], t[1], model]
            weight_shapes.append(new_t)
        for n_t in n_tp:
            new_t = [n_t[0] // tp_size, n_t[1], model]
            weight_shapes.append(new_t)
        for k_t in k_tp:
            new_t = [k_t[0], k_t[1] // tp_size, model]
            weight_shapes.append(new_t)
    return weight_shapes


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
    for N, K, model_name in NK_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
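A note on running it: the script is self-contained apart from its vLLM baseline (vllm._custom_ops.cutlass_scaled_mm), so with vLLM installed it can be invoked directly, e.g. python bench_fp8_blockwise_gemm.py --models deepseek-ai/DeepSeek-V3 --tp-sizes 1, which sweeps the DeepSeek-V3 weight shapes above and saves plots under bench_fp8_blockwise_res/.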
sgl-kernel/setup.py

@@ -96,6 +96,7 @@ sources = [
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "src/sgl-kernel/csrc/eagle_utils.cu",
sgl-kernel/src/sgl-kernel/__init__.py

@@ -14,6 +14,7 @@ from sgl_kernel.ops import (
     build_tree_kernel_efficient,
     custom_dispose,
     custom_reduce,
+    fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     fused_add_rmsnorm,
     gelu_and_mul,
@@ -44,6 +45,7 @@ __all__ = [
     "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
+    "fp8_blockwise_scaled_mm",
     "fp8_scaled_mm",
     "fused_add_rmsnorm",
     "gelu_and_mul",
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp  (new file, 0 → 100644)
// Adapt from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_buildler.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once
#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
    class ElementA,
    class GmemLayoutATag,
    int AlignmentA,
    class ElementB,
    class GmemLayoutBTag,
    int AlignmentB,
    class ElementAccumulator,
    class TileShape_MNK,
    class ClusterShape_MNK,
    class StageCountType,
    int ScaleGranularityM>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm =
      (cute::is_any_of_v<KernelScheduleType,
                         KernelPtrArrayTmaWarpSpecializedCooperative,
                         KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert(
      (!IsFP8Input || !IsArrayOfPointersGemm),
      "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative =
      cute::is_any_of_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative,
                        KernelPtrArrayTmaWarpSpecializedCooperative,
                        KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative, Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
      AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage =
      IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<
      detail::sm90_smem_capacity_bytes - KernelSmemCarveout, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy =
      MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp  (new file, 0 → 100644)

(Diff collapsed on the original page; the 733 added lines are not shown here.)
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp  (new file, 0 → 100644)
// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
#pragma once
#include <cutlass/gemm/dispatch_policy.hpp>
namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                               // while zero-value `ScaleGranularityM` indicates that scaling
                               // granularity is `size<0>(TileShape_MNK{})` along M.
    >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
  static_assert(
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};

//////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm
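Note that `ScaleGranularityM` defaults to 0 here, meaning one scale per `size<0>(TileShape_MNK{})` rows of the tile. The launcher in fp8_blockwise_gemm_kernel.cu below instantiates the schedule with its own default of `ScaleGranularityM = 1`, i.e. one scale per row of A, which matches the (1, 128) per-token-group scales used on the Python side.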
sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu  (new file, 0 → 100644)
#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "utils.h"
using namespace cute;

template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}

template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}

torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched");
  TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
  TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched");
  TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
    } else {
      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
    }
    return out;
  }
#endif
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
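Summarizing the checks above: mat_a must be row major and mat_b column major, both Float8_e4m3fn with 16-byte-aligned rows/columns; scales_a must be float32 of shape (M, K/128) and M major, scales_b float32 of shape (K/128, N/128) and K major; out_dtype must be Half or BFloat16; and the op only dispatches on SM90+ GPUs when built with CUDA 12 or newer, otherwise it raises a not-implemented error.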
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h

@@ -60,6 +60,11 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
     const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
     const c10::optional<torch::Tensor>& bias);
 
+// fp8_blockwise_scaled_mm
+torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+                                      const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+                                      const torch::Dtype& out_dtype);
+
 // lightning_attention_decode
 void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
sgl-kernel/src/sgl-kernel/ops/__init__.py

@@ -125,6 +125,16 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     )
 
 
+def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
+    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
+        mat_a,
+        mat_b,
+        scales_a,
+        scales_b,
+        out_dtype,
+    )
+
+
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     return torch.ops.sgl_kernels.fp8_scaled_mm(
         mat_a,
sgl-kernel/src/sgl-kernel/torch_extension.cc

@@ -54,6 +54,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "bias) -> Tensor");
   m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
 
+  // fp8_blockwise_scaled_mm
+  m.def(
+      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
+      "Tensor");
+  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
+
   // lightning_attention_decode
   m.def(
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
sgl-kernel/tests/test_fp8_blockwise_gemm.py  (new file, 0 → 100644)
import unittest
from typing import Optional, Type

import torch
from sgl_kernel import fp8_blockwise_scaled_mm


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: Type[torch.dtype],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # We treat N-dimensional group scaling as extended numpy-style broadcasting:
    # numpy simply stretches dimensions with an extent of 1 to match the target
    # shape by repeating the data along that dimension (broadcasting). We extend
    # these semantics to say that if the extent of a dimension in the source
    # shape is not 1 and does not match the target shape, we repeat each element
    # along that dimension src_shape[dim] // target_shape[dim] times.
    # Example: if we have
    #   a = [[1, 2],  and target_shape = (2, 4)
    #        [3, 4]]
    # then we would expand a to:
    #   a = [[1, 1, 2, 2],
    #        [3, 3, 4, 4]]
    # NOTE: this function does not explicitly broadcast dimensions with an
    # extent of 1, since this can be done implicitly by pytorch.
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    output = torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    ).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


class TestFp8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, out_dtype, device):
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        b_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

        scale_a_group_shape = (1, 128)
        scale_b_group_shape = (128, 128)
        scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
        scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

        scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
        scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
        scale_a = scale_a.t().contiguous().t()
        scale_b = scale_b.t().contiguous().t()

        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
        o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
        o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
        rtol = 0.02
        atol = 1
        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        out_dtypes = [torch.bfloat16, torch.float16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for out_dtype in out_dtypes:
                        self._test_accuracy_once(M, N, K, out_dtype, "cuda")


if __name__ == "__main__":
    unittest.main()
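The test sweeps M, N, K and both output dtypes, comparing the kernel against the float32 reference baseline_scaled_mm with rtol=0.02 and atol=1. It can be run directly with python sgl-kernel/tests/test_fp8_blockwise_gemm.py on a CUDA device, which per the dispatch logic above needs to be SM90 or newer.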