"tests/vscode:/vscode.git/clone" did not exist on "21b3671bbc508662561ae95a418a26dbe71db356"
Unverified Commit 841810f2 authored by henryg, committed by GitHub

[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)

parent 733446dd
import argparse
import copy
import itertools
from typing import Optional, Tuple
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
@@ -69,6 +71,21 @@ WEIGHT_SHAPES = {
}


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor FP8 (e4m3) quantization: uses the provided scale when given (static),
    # otherwise lets the kernel compute one on the fly (dynamic).
    fp8_type_: torch.dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
@@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    dtype = torch.float16 if "fp16" in provider else torch.bfloat16
    if "vllm-fp8" in provider:
        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
......
@@ -48,6 +48,7 @@ limitations under the License.
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "math.hpp"
#include "utils.h"
using namespace cute;
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  // Tile and cluster shapes, chosen below based on the (padded) M dimension.
  using CTAShapeDefault = Shape<_256, _128, _64>;
  using ClusterShapeDefault = Shape<_2, _2, _1>;

  using CTAShape256 = Shape<_128, _128, _128>;
  using ClusterShape256 = Shape<_2, _1, _1>;

  using CTAShape64 = Shape<_64, _64, _128>;
  using ClusterShape64 = Shape<_1, _1, _1>;

  using CTAShape16 = Shape<_64, _64, _128>;
  using ClusterShape16 = Shape<_1, _4, _1>;

  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileSchedulerType = void;
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
  using ElementOutput = OutType;
  using AccumElementType = float;

  // Gemm types with bias
  using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape256,
      ClusterShape256,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape64,
      ClusterShape64,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape16,
      ClusterShape16,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  // Gemm types without bias
  using GemmDefault = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm256 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape256,
      ClusterShape256,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm64 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape64,
      ClusterShape64,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm16 = DeviceGemmFp8RowwiseSm100<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape16,
      ClusterShape16,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;

  // next power of 2 (minimum 16)
  uint32_t const m = a.size(0);
  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
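  // Smaller (padded) M selects smaller CTA tiles and clusters below; M above 256
  // keeps the default 256x128x64 tile with a 2x2x1 cluster.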
  if (bias) {
    if (mp2 <= 16) {
      // m in [1, 16]
      return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      // m in (16, 64]
      return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      // m in (64, 256]
      return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
    } else {
      // m in (256, inf)
      return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    if (mp2 <= 16) {
      // m in [1, 16]
      return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      // m in (16, 64]
      return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      // m in (64, 256]
      return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
    } else {
      // m in (256, inf)
      return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
    }
  }
}
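
// For reference only (not part of this commit): a minimal standalone sketch that
// mirrors the M-bucketing above and prints which tile/cluster configuration each M
// would select. It reproduces next_pow_2 locally (as next_pow_2_sketch) so it
// compiles on its own with GCC/Clang; the thresholds and shape names are copied
// from sm100_fp8_dispatch_bias.
#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

static constexpr uint32_t next_pow_2_sketch(uint32_t num) {
  if (num <= 1) return num;
  return 1u << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  for (uint32_t m : {1u, 16u, 17u, 64u, 65u, 256u, 257u, 4096u}) {
    uint32_t const mp2 = std::max(16u, next_pow_2_sketch(m));
    char const* cfg = mp2 <= 16    ? "CTA 64x64x128,   cluster 1x4x1 (Gemm16)"
                      : mp2 <= 64  ? "CTA 64x64x128,   cluster 1x1x1 (Gemm64)"
                      : mp2 <= 256 ? "CTA 128x128x128, cluster 2x1x1 (Gemm256)"
                                   : "CTA 256x128x64,  cluster 2x2x1 (GemmDefault)";
    std::printf("M = %4u -> %s\n", static_cast<unsigned>(m), cfg);
  }
  return 0;
}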
......
#pragma once
#include <climits>
#include <iostream>
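// Smallest power of two >= num; exact powers of two (and 0 and 1) are returned unchanged.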
inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}
// Round a down to the previous multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
return a % b == 0 ? a : (a / b) * b;
}
// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
return a % b == 0 ? a : ((a / b) + 1) * b;
}
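
// Illustrative only (not part of this commit): a quick standalone sanity check of
// the helpers above, assuming their definitions are in scope (for example by
// including the header they live in; that header's name is not shown here).
#include <cassert>

int main() {
  assert(next_pow_2(1) == 1);    // 1 is already a power of two
  assert(next_pow_2(16) == 16);  // exact powers of two are returned unchanged
  assert(next_pow_2(17) == 32);  // otherwise round up to the next power of two
  assert(div_ceil(1000, 256) == 4);
  assert(round_to_previous_multiple_of(1000u, 256u) == 768u);
  assert(round_to_next_multiple_of(1000u, 256u) == 1024u);
  return 0;
}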