"symphony/client/pyproject.toml" did not exist on "02782735c370f2cfe36b641f9eb3c9d0cf2e15f8"
Unverified Commit 82392da8 authored by HandH1998's avatar HandH1998 Committed by GitHub
Browse files

support w8a8 fp8 kernel with CUTLASS (#3047)


Co-authored-by: default avataryych0745 <1398089567@qq.com>
parent 95f789ad
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
line_names=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="fp8 scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
# M, N, K = batch_size, 4096, 8192
M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
quantiles = [0.5, 0.2, 0.8]
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
if "vllm-fp8" in provider:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
elif "sglang-fp8" in provider:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
),
quantiles=quantiles,
)
gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
)
print("Benchmark finished!")
......@@ -56,6 +56,7 @@ include_dirs = [
turbomind.resolve(),
turbomind.resolve() / "src",
]
nvcc_flags = [
"-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}",
......@@ -82,6 +83,7 @@ sources = [
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"3rdparty/flashinfer/csrc/activation.cu",
......
......@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
bmm_fp8,
custom_dispose,
custom_reduce,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
......@@ -27,6 +28,7 @@ __all__ = [
"bmm_fp8",
"custom_dispose",
"custom_reduce",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
......
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h
#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "utils.h"
using namespace cute;
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
typename WarpShape, int Stages, bool WithBias, typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
using ElementA = ElementType;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementType;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementC = OutElementType;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using ElementOutput = OutElementType;
using LayoutOutput = cutlass::layout::RowMajor;
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
using ElementAccumulator = AccumElementType;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm89;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
// Number of epilogue stages in EVT
static constexpr int EVTEpilogueStages = 1;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<CtaShape, WarpShape, ElementC,
AlignmentC, EVTEpilogueStages>;
// Definition of EVT
using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
Stride<_0, _1, _0>>;
using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;
using ComputeAScale =
cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
Stride<_1, _0, _0>>;
using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;
// With bias
using biasSrc =
cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
using ComputeAScaleWithBias =
cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiply_add, ElementC, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueAScaleWithBias =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;
using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride<int64_t, _1, _0>>;
using EpilogueStore =
typename cutlass::platform::conditional<WithBias,
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
using EpilogueOp = EpilogueStore;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB,
cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator,
ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp,
ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
ElementOutput const* ptr_bias = nullptr;
if constexpr (WithBias) {
TORCH_CHECK(bias.has_value())
ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
}
ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode
{m, n, k}, // Problem size
1, // Split-k factor
{}, // Epilogue args
ptr_a, // a pointer
ptr_b, // b pointer
nullptr, // c pointer (unused)
nullptr, // d pointer (unused)
m * k, // batch stride a (unused)
n * k, // batch stride b (unused)
m * n, // batch stride c (unused)
m * n, // batch stride d (unused)
lda, // stride a
ldb, // stride b
ldc, // stride c (unused)
ldc); // stride d (unused)
if constexpr (WithBias) {
args.epilogue = {{
{
{}, // Accumulator
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
{} // Multiplies
},
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
{ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
{} // Multiplies
},
{ptr_d, {n, _1{}, _0{}}}};
} else {
args.epilogue = {{
{
{}, // Accumulator
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
{} // Multiplies
},
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
{} // Multiplies
},
{ptr_d, {n, _1{}, _0{}}}};
}
return args;
}
template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess)
}
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
if (bias) {
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
Stages, true>::Gemm;
return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else {
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
Stages, false>::Gemm;
return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
}
}
template <typename OutType>
void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0);
uint32_t const n = out.size(1);
if (m == 1) {
if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 16) {
// M in (1, 16]
if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 64) {
// M in (16, 64]
if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 128) {
// M in (64, 128]
if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 256) {
// M in (128, 256]
if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>,
cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
}
} else if (m <= 512) {
// M in (256, 512)
if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
}
} else {
// M in (512, inf)
if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
}
}
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
typename TileSchedulerType = void, bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
// A matrix configuration
using ElementA = ElementType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = ElementType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = void; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<OutElementType>::value; // Memory access granularity/alignment of C matrices in
// units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = OutElementType; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
// // Auxiliary matrix configuration and other fusion types
// using ElementBias = float;
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = AccumElementType; // Element type for internal accumulation
using ElementCompute = float; // Element type for compute
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = CTAShape; // Threadblock-level tile size
static constexpr bool PONG = false;
static constexpr bool FAST_ACCUM = true;
static constexpr bool USE_BIAS = false;
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
// setting in the Collective Builder
// Implement rowwise scaling epilogue.
using XScale =
cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using WScale =
cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
ElementComputeEpilogue, // First stage output type.
ElementComputeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput,
ElementComputeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
// With bias
using ComputeWithBias =
cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC,
AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized,
EpilogueEVT>::CollectiveOp;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using SlowAccum = DefaultSchedule;
using FastAccum = FastPongSchedule; // Default apply Pingpong
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
ElementOutput const* ptr_bias = nullptr;
if constexpr (WithBias) {
TORCH_CHECK(bias.has_value())
ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
}
ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{ptr_a, stride_a, ptr_b, stride_b},
{{}, // epilogue.thread
nullptr,
stride_c,
ptr_d,
stride_d}};
if constexpr (WithBias) {
args.epilogue.thread = {
{ptr_scales_a},
{
{ptr_scales_b},
{}, // Accumulator
{} // Multiplies
},
{ptr_bias},
{}, // Multiplies
};
} else {
args.epilogue.thread = {
{ptr_scales_a},
{
{ptr_scales_b},
{}, // Accumulator
{} // Multiplies
},
{}, // Multiplies
};
}
return args;
}
template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess)
}
template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
typename TileSchedulerType>
void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias, bool fast_accum = true,
bool use_persistent = false) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
if (bias) {
using Gemm =
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else {
using Gemm =
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
}
}
template <typename OutType>
void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0);
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
using BasicTileScheduler = void;
if (m <= 1) {
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler,
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
if (m <= 64) {
// m in [1, 64]
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 256) {
// m in (64, 256]
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 1024) {
// m in (256, 1024]
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (1024, inf)
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
}
#endif
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
"mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
"mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous");
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
if (bias) {
TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
}
torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");
auto sm_version = getSMVersion();
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if (sm_version >= 90) {
if (out_dtype == torch::kBFloat16) {
sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
return out;
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
if (sm_version == 89) {
if (out_dtype == torch::kBFloat16) {
sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
return out;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}
......@@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
// fp8_scaled_mm
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
// lightning_attention_decode
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
......
......@@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
)
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernels.fp8_scaled_mm(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
torch.ops.sgl_kernels.lightning_attention_decode(
q, k, v, past_kv, slope, output, new_kv
......
......@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
"bias) -> Tensor");
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
// fp8_scaled_mm
m.def(
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
// lightning_attention_decode
m.def(
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
......
import unittest
import torch
from sgl_kernel import fp8_scaled_mm
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
o = o.to(torch.float32)
temp1 = o * scale_a.view(-1, 1)
temp2 = temp1 * scale_b.view(1, -1)
final = temp2.to(out_dtype)
if bias is not None:
final = final + bias.view(1, -1)
return final
class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
b_fp8 = b_fp8.t()
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 128, 512, 1024, 4096]
Ns = [16, 128, 512, 1024, 4096]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.bfloat16, torch.float16]
for M in Ms:
for N in Ns:
for K in Ks:
for with_bias in bias_opts:
for out_dtype in out_dtypes:
self._test_accuracy_once(
M, N, K, with_bias, out_dtype, "cuda"
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment