Unverified commit ebf495f0, authored by Yi Zhang, committed by GitHub

sgl-kernel use cutlass latest version for fp8 blockwise gemm (#5207)

parent 7f875f12
@@ -2,18 +2,22 @@ import argparse
 import copy
 import itertools

+import deep_gemm
 import torch
 import triton
+from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm

+from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+

 def get_weight_shapes(args):
     models_tps = list(itertools.product(args.models, args.tp_sizes))
     # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
     # cannot TP
     total = [
-        # (512 + 64, 7168), # this weight is not supported by current kernel
+        (512 + 64, 7168),
         ((128 + 64) * 128, 7168),
         (128 * (128 + 128), 512),
         (7168, 16384),
@@ -52,6 +56,23 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+def fp8_gemm_deepgemm(
+    x_fp8: torch.Tensor,
+    x_scale: torch.Tensor,
+    y_fp8: torch.Tensor,
+    y_scale: torch.Tensor,
+    m: int,
+    n: int,
+    k: int,
+):
+    """DeepGEMM implementation of FP8 GEMM"""
+    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
+
+    # Run DeepGEMM kernel
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    return out
+
+
 def scale_shape(shape, group_shape):
     assert len(shape) == len(group_shape)
     return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
@@ -60,12 +81,12 @@ def scale_shape(shape, group_shape):
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel"],
-        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
+        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
+        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
         ylabel="GB/s",
         plot_name="fp8 blockwise scaled matmul",
         args={},
@@ -80,7 +101,7 @@ def benchmark(batch_size, provider, N, K):
     a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
     b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
-    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

     scale_a_group_shape = (1, 128)
     scale_b_group_shape = (128, 128)
@@ -89,11 +110,11 @@ def benchmark(batch_size, provider, N, K):
     scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
     scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)

-    scale_a = scale_a.t().contiguous().t()
-    scale_b = scale_b.t().contiguous().t()
     quantiles = [0.5, 0.2, 0.8]

     if provider == "sgl-kernel":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: fp8_blockwise_scaled_mm(
                 a_fp8, b_fp8, scale_a, scale_b, torch.float16
@@ -101,19 +122,28 @@ def benchmark(batch_size, provider, N, K):
             quantiles=quantiles,
         )
     if provider == "vllm":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
-    gbps = (
-        lambda ms: (
-            (2 * M * N * K - M * N) * a_fp8.element_size()
-            + (3 * M * N) * scale_a.element_size()
-        )
-        * 1e-9
-        / (ms * 1e-3)
-    )
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    if provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: w8a8_block_fp8_matmul(
+                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
+            ),
+            quantiles=quantiles,
+        )
+    if provider == "deepgemm":
+        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: fp8_gemm_deepgemm(
+                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
+            ),
+            quantiles=quantiles,
+        )
+    return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms


 if __name__ == "__main__":
@@ -136,6 +166,9 @@ if __name__ == "__main__":
     NK_model_names = get_weight_shapes(args)

     for N, K, model_name in NK_model_names:
+        if N % 128 != 0 or K % 128 != 0:
+            print(f"Skip {N=}, {K=} now")
+            continue
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,
...
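For reference, the four providers benchmarked above do not accept identical operand layouts. The sketch below is not part of the diff; it condenses the layout handling from the benchmark into a single standalone snippet. It assumes a Hopper GPU with sgl_kernel, vllm, sglang, and deep_gemm installed; the shapes are illustrative (N and K multiples of 128), the scales are random as in the benchmark, and helper names such as scale_a_cm and out_sgl are mine.

# Layout sketch for the four benchmarked providers (illustrative shapes,
# random scales, so only the call conventions matter here).
import torch
import deep_gemm
from deep_gemm import get_col_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

M, N, K = 128, 512, 7168  # N and K must be multiples of 128
fp8 = torch.finfo(torch.float8_e4m3fn)

# Row-major fp8 operands: a is (M, K), b is (N, K).
a_fp8 = (torch.rand(M, K, device="cuda") * fp8.max).to(torch.float8_e4m3fn)
b_fp8 = (torch.rand(N, K, device="cuda") * fp8.max).to(torch.float8_e4m3fn)

# Per-(1, 128) scales for a, per-(128, 128) scales for b, both fp32.
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)

# sgl-kernel and vllm CUTLASS paths: b and scale_b are passed transposed
# (column-major views), and scale_a is made column-major.
scale_a_cm = scale_a.t().contiguous().t()
out_sgl = fp8_blockwise_scaled_mm(a_fp8, b_fp8.t(), scale_a_cm, scale_b.t(), torch.float16)
out_vllm = vllm_scaled_mm(a_fp8, b_fp8.t(), scale_a_cm, scale_b.t(), torch.float16)

# sglang triton kernel: row-major b, plain scales, block shape passed explicitly.
out_triton = w8a8_block_fp8_matmul(a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16)

# DeepGEMM: scale_a must be column-major and TMA-aligned; output dtype is bf16.
out_dg = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
deep_gemm.gemm_fp8_fp8_bf16_nt(
    (a_fp8, get_col_major_tma_aligned_tensor(scale_a.clone())),
    (b_fp8, scale_b),
    out_dg,
)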
// Adapt from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once
#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
class ElementA,
class GmemLayoutATag,
int AlignmentA,
class ElementB,
class GmemLayoutBTag,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
int ScaleGranularityM
>
struct CollectiveBuilder<
arch::Sm90,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATag,
AlignmentA,
ElementB,
GmemLayoutBTag,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
cute::enable_if_t<
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
TagToStrideA_t<GmemLayoutATag>,
ElementB,
TagToStrideB_t<GmemLayoutBTag>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
#pragma once
#include <cutlass/gemm/dispatch_policy.hpp>
namespace cutlass::gemm {
//////////////////////////////////////////////////////////////////////////////
// FP8 related policies (including Blocked Scaled Accumulation)
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
// `ScaleGranularityM` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
int Stages_,
class ClusterShape_ = Shape<_1, _1, _1>,
class KernelSchedule = KernelTmaWarpSpecialized,
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
"KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm
@@ -30,13 +30,16 @@
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 #include <cutlass/util/packed_stride.hpp>

-#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
-#include "cutlass_extensions/gemm/dispatch_policy.hpp"
 #include "utils.h"

 using namespace cute;

-template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
+template <
+    typename SchedulerType,
+    typename OutType,
+    typename TileShape,
+    typename ClusterShape,
+    typename ScaleGranularity>
 void launch_sm90_fp8_blockwise_scaled_mm(
     torch::Tensor& out,
     const torch::Tensor& a,
@@ -63,6 +66,9 @@ void launch_sm90_fp8_blockwise_scaled_mm(
   using LayoutD = cutlass::layout::RowMajor;
   constexpr int AlignmentD = AlignmentC;

+  static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
+  static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
+
   using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
   using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
@@ -70,7 +76,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
   using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;

   using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
       OperatorClass,
@@ -108,7 +114,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
       Shape<int, int, int, int>,  // Indicates ProblemShape
       CollectiveMainloop,
       CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>;
+      SchedulerType>;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

   Gemm gemm_op;
@@ -299,8 +305,26 @@ void sm90_fp8_blockwise_dispatch_shape(
     const torch::Tensor& scales_a,
     const torch::Tensor& scales_b) {
   using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_1, _1, _1>;
-  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
+  using ClusterShape = Shape<_1, _2, _1>;
+  using ScaleGranularity = Shape<_1, _128, _128>;
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    launch_sm90_fp8_blockwise_scaled_mm<
+        cutlass::gemm::StreamKScheduler,
+        OutType,
+        TileShape,
+        ClusterShape,
+        ScaleGranularity>(out, a, b, scales_a, scales_b);
+  } else {
+    launch_sm90_fp8_blockwise_scaled_mm<
+        cutlass::gemm::PersistentScheduler,
+        OutType,
+        TileShape,
+        ClusterShape,
+        ScaleGranularity>(out, a, b, scales_a, scales_b);
+  }
 }

 template <typename OutType>
@@ -372,10 +396,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
   if (sm_version == 90) {
+    torch::Tensor scales_b_contiguous = scales_b.contiguous();
     if (out_dtype == torch::kBFloat16) {
-      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
+      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
     } else {
-      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
+      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
     }
     return out;
   }
...
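The dispatch change above has two user-visible consequences: scales_b is copied to contiguous memory inside the binding before dispatch, so a transposed (non-contiguous) view is accepted, and the tile scheduler is picked per shape, stream-K when K > 3 * N and persistent otherwise. The Python sketch below is mine, not part of the diff; it reuses the layout conventions from the benchmark, the shapes are illustrative, and the scheduler choice is internal to the C++ dispatch, so both calls look identical from Python.

# Two calls that land on different scheduler branches inside
# sm90_fp8_blockwise_dispatch_shape (K > 3 * N -> StreamKScheduler,
# otherwise PersistentScheduler).
import torch
from sgl_kernel import fp8_blockwise_scaled_mm

def call(M, N, K):
    a = torch.rand(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn)
    scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
    scale_b = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)
    # scale_b.t() is a non-contiguous view; the binding now calls
    # .contiguous() on it before dispatch (scales_b_contiguous above).
    return fp8_blockwise_scaled_mm(
        a, b.t(), scale_a.t().contiguous().t(), scale_b.t(), torch.bfloat16
    )

call(64, 512, 7168)  # K > 3 * N: stream-K scheduler
call(64, 7168, 512)  # K <= 3 * N: persistent scheduler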
@@ -82,9 +82,9 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
     print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")


-@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
-@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
-@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 14080])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 14080, 16384])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
 def test_accuracy(M, N, K, out_dtype):
     _test_accuracy_once(M, N, K, out_dtype, "cuda")
...
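The body of _test_accuracy_once is collapsed in this hunk. For orientation only, a blockwise-scaled fp8 GEMM is typically checked against a plain fp32 reference in which each (1, 128) block of a and each (128, 128) block of b carries its own scale; the helper below is a sketch of that math under those assumptions, with a hypothetical name, not code from the repository.

# Reference sketch: expand block scales to full size, dequantize, and do an
# fp32 matmul. Requires N and K to be multiples of the block size.
import torch

def blockwise_dequant_reference(a_fp8, scale_a, b_fp8, scale_b, block=128):
    # a_fp8: (M, K) with scale_a of shape (M, K // block)
    # b_fp8: (N, K) with scale_b of shape (N // block, K // block)
    a = a_fp8.to(torch.float32) * scale_a.repeat_interleave(block, dim=1)
    b = b_fp8.to(torch.float32) * scale_b.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return a @ b.t()  # (M, N) reference output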