Unverified commit 82392da8, authored Jan 26, 2025 by HandH1998, committed by GitHub on Jan 26, 2025
Parent: 95f789ad

support w8a8 fp8 kernel with CUTLASS (#3047)

Co-authored-by: yych0745 <1398089567@qq.com>
Showing 8 changed files with 881 additions and 0 deletions (+881 / -0)
  sgl-kernel/benchmark/bench_fp8_gemm.py                +164  -0
  sgl-kernel/setup.py                                   +2    -0
  sgl-kernel/src/sgl-kernel/__init__.py                 +2    -0
  sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu     +624  -0
  sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h   +5    -0
  sgl-kernel/src/sgl-kernel/ops/__init__.py             +11   -0
  sgl-kernel/src/sgl-kernel/torch_extension.cc          +6    -0
  sgl-kernel/tests/test_fp8_gemm.py                     +67   -0
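The commit exposes a row-wise scaled FP8 (e4m3) GEMM from Python as fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): mat_a is row-major FP8, mat_b is column-major FP8 (an [N, K] weight transposed), scales_a/scales_b are per-row and per-column float32 scales, and the output is fp16 or bf16. A minimal usage sketch, modeled on the new test file; the random quantization below is illustrative only and not part of the commit:

    import torch
    from sgl_kernel import fp8_scaled_mm

    M, N, K = 128, 4096, 4096
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # Toy FP8 inputs: activations [M, K] row-major, weights [N, K] transposed to column-major.
    a_fp8 = (torch.rand(M, K, device="cuda") * fp8_max).to(torch.float8_e4m3fn)
    b_fp8 = (torch.rand(N, K, device="cuda") * fp8_max).to(torch.float8_e4m3fn).t()
    scale_a = torch.rand(M, device="cuda", dtype=torch.float32) * 0.001  # per-row scales
    scale_b = torch.rand(N, device="cuda", dtype=torch.float32) * 0.001  # per-column scales

    # out[i, j] ~= (a @ b)[i, j] * scale_a[i] * scale_b[j] (+ bias[j]), returned here in bf16.
    out = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.bfloat16, bias=None)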
sgl-kernel/benchmark/bench_fp8_gemm.py
0 → 100644
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Weight Shapes are in the format
#   ([K, N], TP_SPLIT_DIM)
# Example:
#   A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#   A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    b_fp8 = b_fp8.t()
    quantiles = [0.5, 0.2, 0.8]
    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )

    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
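The benchmark compares the vLLM CUTLASS path against the new sgl-kernel path over the weight shapes above. It can be run directly, for example with python bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1, and writes its tables and plots under bench_fp8_res/; it assumes a vLLM installation for the cutlass_scaled_mm and scaled_fp8_quant baselines.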
sgl-kernel/setup.py
...
@@ -56,6 +56,7 @@ include_dirs = [
     turbomind.resolve(),
     turbomind.resolve() / "src",
 ]

 nvcc_flags = [
     "-DNDEBUG",
     f"-DOPERATOR_NAMESPACE={operator_namespace}",
...
@@ -82,6 +83,7 @@ sources = [
     "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
...
sgl-kernel/src/sgl-kernel/__init__.py
...
@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
     bmm_fp8,
     custom_dispose,
     custom_reduce,
+    fp8_scaled_mm,
     fused_add_rmsnorm,
     gelu_and_mul,
     gelu_tanh_and_mul,
...
@@ -27,6 +28,7 @@ __all__ = [
     "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
+    "fp8_scaled_mm",
     "fused_add_rmsnorm",
     "gelu_and_mul",
     "gelu_tanh_and_mul",
...
sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
0 → 100644
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h
#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "utils.h"
using namespace cute;
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
          typename WarpShape, int Stages, bool WithBias,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
          template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
          typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  using ElementA = ElementType;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementType;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = OutElementType;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using ElementOutput = OutElementType;
  using LayoutOutput = cutlass::layout::RowMajor;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using ElementAccumulator = AccumElementType;
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm89;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
  // Number of epilogue stages in EVT
  static constexpr int EVTEpilogueStages = 1;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC,
                                                             EVTEpilogueStages>;

  // Definition of EVT
  using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;

  using ComputeBScale =
      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementComputeEpilogue,
                                                     ElementComputeEpilogue,
                                                     cutlass::FloatRoundStyle::round_to_nearest>;
  using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
                                                                        Stride<_0, _1, _0>>;
  using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;

  using ComputeAScale =
      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue,
                                                     cutlass::FloatRoundStyle::round_to_nearest>;
  using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
                                                                        Stride<_1, _0, _0>>;
  using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;

  // With bias
  using biasSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput,
                                                                      Stride<_0, _1, _0>>;
  using ComputeAScaleWithBias =
      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiply_add, ElementC, ElementComputeEpilogue,
                                                     cutlass::FloatRoundStyle::round_to_nearest>;
  using EpilogueAScaleWithBias =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;

  using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<OutputTileThreadMap, ElementC,
                                                               cutlass::FloatRoundStyle::round_to_nearest,
                                                               Stride<int64_t, _1, _0>>;
  using EpilogueStore =
      typename cutlass::platform::conditional<WithBias,
                                              cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
                                              cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;

  using EpilogueOp = EpilogueStore;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB,
      cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator,
      ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp,
      ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
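// Reading of the visitor tree above: the fp32 accumulator is first multiplied by the row-broadcast
// per-column scale (scales_b), then by the column-broadcast per-row scale (scales_a), optionally fused
// with a row-broadcast bias, i.e. D[i, j] = acc[i, j] * scales_b[j] * scales_a[i] (+ bias[j]), with the
// scaling performed in fp32 before the store casts to the output element type.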
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                                               const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value())
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
                                {m, n, k},                                // Problem size
                                1,                                        // Split-k factor
                                {},                                       // Epilogue args
                                ptr_a,                                    // a pointer
                                ptr_b,                                    // b pointer
                                nullptr,                                  // c pointer (unused)
                                nullptr,                                  // d pointer (unused)
                                m * k,                                    // batch stride a (unused)
                                n * k,                                    // batch stride b (unused)
                                m * n,                                    // batch stride c (unused)
                                m * n,                                    // batch stride d (unused)
                                lda,                                      // stride a
                                ldb,                                      // stride b
                                ldc,                                      // stride c (unused)
                                ldc);                                     // stride d (unused)
  if constexpr (WithBias) {
    args.epilogue = {{
                         {
                             {},  // Accumulator
                             {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                             {}  // Multiplies
                         },
                         {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
                         {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
                         {}  // Multiplies
                     },
                     {ptr_d, {n, _1{}, _0{}}}};
  } else {
    args.epilogue = {{
                         {
                             {},  // Accumulator
                             {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                             {}  // Multiplies
                         },
                         {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
                         {}  // Multiplies
                     },
                     {ptr_d, {n, _1{}, _0{}}}};
  }

  return args;
}

template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                               const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess)

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess)
}
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  if (bias) {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
                                                   Stages, true>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
                                                   Stages, false>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}
template <typename OutType>
void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                             const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  uint32_t const n = out.size(1);

  if (m == 1) {
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 16) {
    // M in (1, 16]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    // M in (16, 64]
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    // M in (64, 128]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
                                    cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
                                    cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 256) {
    // M in (128, 256]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>,
                                    cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 512) {
    // M in (256, 512)
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    // M in (512, inf)
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
    }
  }
}
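// Note: the (CtaShape, WarpShape, Stages) combinations above are per-(M, N) tuning buckets for SM89,
// ranging from small 16x64 CTAs for skinny, GEMV-like problems up to 128x128 CTAs for large M.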
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
          typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
          typename TileSchedulerType = void, bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  // A matrix configuration
  using ElementA = ElementType;               // Element type for A matrix operand
  using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A
                                                    // matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = ElementType;                  // Element type for B matrix operand
  using LayoutB = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B
                                                    // matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = void;                      // Element type for C matrix operands
  using LayoutC = cutlass::layout::RowMajor;  // Layout type for C matrix operands
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<OutElementType>::value;  // Memory access granularity/alignment of C matrices in
                                                          // units of elements (up to 16 bytes)

  // Output matrix configuration
  using ElementOutput = OutElementType;            // Element type for output matrix operands
  using LayoutOutput = cutlass::layout::RowMajor;  // Layout type for output matrix operands
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // // Auxiliary matrix configuration and other fusion types
  // using ElementBias = float;

  // Multiply-accumulate blocking/pipelining details
  using ElementAccumulator = AccumElementType;  // Element type for internal accumulation
  using ElementCompute = float;                 // Element type for compute
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
  using TileShape = CTAShape;                             // Threadblock-level tile size

  static constexpr bool PONG = false;
  static constexpr bool FAST_ACCUM = true;
  static constexpr bool USE_BIAS = false;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;      // Stage count maximized
                                                                         // based on the tile size
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;  // Kernel to launch based on the default
                                                                         // setting in the Collective Builder
  // Implement rowwise scaling epilogue.
  using XScale =
      cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
                                                  cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  using WScale =
      cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
                                                  cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
                                                           cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
                                                          ElementComputeEpilogue,  // First stage output type.
                                                          ElementComputeEpilogue,  // First stage input types.
                                                          cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput,
                                                          ElementComputeEpilogue,  // Second stage input types.
                                                          cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias
  using ComputeWithBias =
      cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementComputeEpilogue,
                                             cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC,
      AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized,
      EpilogueEVT>::CollectiveOp;

  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

  using SlowAccum = DefaultSchedule;
  using FastAccum = FastPongSchedule;  // Default apply Pingpong

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
                                                          CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
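// On SM90 the same row-wise scaling is expressed with SM90 EVT nodes: WScale row-broadcasts the
// per-column scales_b onto the accumulator, XScale column-broadcasts the per-row scales_a, and the
// optional Bias node fuses a per-column bias via multiply_add, so
// D[i, j] = scales_a[i] * (scales_b[j] * acc[i, j]) (+ bias[j]).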
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                                               const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value())
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
  typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
                                   {m, n, k, 1},
                                   {ptr_a, stride_a, ptr_b, stride_b},
                                   {{},  // epilogue.thread
                                    nullptr,
                                    stride_c,
                                    ptr_d,
                                    stride_d}};
  if constexpr (WithBias) {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {ptr_bias},
        {},  // Multiplies
    };
  } else {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {},  // Multiplies
    };
  }

  return args;
}
template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                               const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess)

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess)
}
template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
          typename TileSchedulerType>
void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias, bool fast_accum = true,
                            bool use_persistent = false) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;

  if (bias) {
    using Gemm =
        typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
                                          MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm =
        typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
                                          MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}
template <typename OutType>
void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                             const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
  using BasicTileScheduler = void;
  if (m <= 1) {
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler,
                                  BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
  if (m <= 64) {
    // m in [1, 64]
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    // m in (64, 256]
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 1024) {
    // m in (256, 1024]
    return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else {
    // m in (1024, inf)
    return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
}
#endif
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                            const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
              "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
              "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0,
              "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    if (out_dtype == torch::kBFloat16) {
      sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
...
@@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias);
 
+// fp8_scaled_mm
+torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
+                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
+                            const c10::optional<torch::Tensor>& bias);
+
 // lightning_attention_decode
 void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
...
sgl-kernel/src/sgl-kernel/ops/__init__.py
...
@@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     )
 
 
+def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
+    return torch.ops.sgl_kernels.fp8_scaled_mm(
+        mat_a,
+        mat_b,
+        scales_a,
+        scales_b,
+        out_dtype,
+        bias,
+    )
+
+
 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     torch.ops.sgl_kernels.lightning_attention_decode(
         q, k, v, past_kv, slope, output, new_kv
...
sgl-kernel/src/sgl-kernel/torch_extension.cc
...
@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "bias) -> Tensor");
   m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
 
+  // fp8_scaled_mm
+  m.def(
+      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
+      "bias) -> Tensor");
+  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
+
   // lightning_attention_decode
   m.def(
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
...
sgl-kernel/tests/test_fp8_gemm.py
0 → 100644
import unittest

import torch
from sgl_kernel import fp8_scaled_mm


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    o = o.to(torch.float32)
    temp1 = o * scale_a.view(-1, 1)
    temp2 = temp1 * scale_b.view(1, -1)
    final = temp2.to(out_dtype)
    if bias is not None:
        final = final + bias.view(1, -1)

    return final


class TestFp8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
        if with_bias:
            bias = torch.randn((N,), device=device, dtype=out_dtype)
        else:
            bias = None
        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
        b_fp8 = b_fp8.t()
        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        rtol = 0.02
        atol = 1
        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [16, 128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        bias_opts = [True, False]
        out_dtypes = [torch.bfloat16, torch.float16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for with_bias in bias_opts:
                        for out_dtype in out_dtypes:
                            self._test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


if __name__ == "__main__":
    unittest.main()
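The test sweeps M/N/K combinations with and without bias for both bf16 and fp16 outputs and compares against an fp32 reference within rtol=0.02, atol=1. Since the module calls unittest.main(), it can be run directly with python tests/test_fp8_gemm.py (or collected by pytest) on an FP8-capable GPU.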