Unverified Commit dc48c4c0 authored by Qi Yuhang, committed by GitHub

[sgl-kernel][2/N]Support Expert Specialization Grouped GEMM (#11534)

parent 6dc9ca8c
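
The change below collapses the per-dtype copies of the pre-compute path into one function templated on the output element type; the caller inspects `output.dtype()` once and picks the instantiation. A minimal sketch of that dispatch shape (illustrative names only; the real functions take the full tensor and pointer-table argument lists shown in the diff, and only the dtype checks and CUTLASS element types are taken from it):

```cuda
#include <torch/all.h>

#include <cutlass/bfloat16.h>
#include <cutlass/half.h>

// Sketch: the output element type is a template parameter, so every
// out-dtype-dependent cast is written once instead of once per dtype branch.
template <typename T>
void pre_compute_sketch(torch::Tensor& out /*, ...tensors, pointer tables... */) {
  T* out_base = static_cast<T*>(out.data_ptr());  // e.g. base pointer for per-expert offsets
  (void)out_base;
}

void dispatch_sketch(torch::Tensor& output) {
  if (output.dtype() == torch::kBFloat16) {
    pre_compute_sketch<cutlass::bfloat16_t>(output);
  } else if (output.dtype() == torch::kFloat16) {
    pre_compute_sketch<cutlass::half_t>(output);
  } else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
```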
@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
   torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
   torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
   torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
-  expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
-      out_ptrs,
-      a_ptrs,
-      b_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      layout_sfa,
-      layout_sfb,
-      lm_problem_sizes,
-      mm_problem_sizes,
-      hm_problem_sizes,
-      output,
-      a,
-      b,
-      scales_a,
-      scales_b,
-      problem_sizes,
-      expert_offsets);
+  const std::string H20_device_type_str("NVIDIA H20");
+  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
+  at::cuda::CUDAGuard device_guard{(char)a.get_device()};
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+  if (output.dtype() == torch::kBFloat16) {
+    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        layout_sfa,
+        layout_sfb,
+        lm_problem_sizes,
+        mm_problem_sizes,
+        hm_problem_sizes,
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        problem_sizes,
+        expert_offsets,
+        is_h20_device,
+        stream);
+  } else if (output.dtype() == torch::kFloat16) {
+    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        layout_sfa,
+        layout_sfb,
+        lm_problem_sizes,
+        mm_problem_sizes,
+        hm_problem_sizes,
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        problem_sizes,
+        expert_offsets,
+        is_h20_device,
+        stream);
+  } else {
+    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+  }
 
   if (output.dtype() == torch::kBFloat16) {
     expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
         out_ptrs,
@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         lm_problem_sizes,
         mm_problem_sizes,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        is_h20_device,
+        stream);
   } else if (output.dtype() == torch::kFloat16) {
     expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
         out_ptrs,
@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         lm_problem_sizes,
         mm_problem_sizes,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        is_h20_device,
+        stream);
   } else {
     TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
   }
...
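
A second thread running through the next file: the `at::cuda::CUDAGuard` and CUDA stream are now taken once at the entry point (as added above) and passed down as an explicit `cudaStream_t`, instead of each helper re-querying `getCurrentCUDAStream`. A minimal sketch of that pattern, with hypothetical helper names:

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// Helpers receive the stream explicitly, so every kernel and grouped GEMM they
// enqueue lands on the stream chosen by the caller.
void helper_sketch(const torch::Tensor& t, cudaStream_t stream) {
  // ... enqueue work on `stream` ...
}

void entry_sketch(const torch::Tensor& a) {
  at::cuda::CUDAGuard device_guard{(char)a.get_device()};  // pin a's device once
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
  helper_sketch(a, stream);  // pass it down rather than re-query in each helper
}
```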
@@ -3,6 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
 
+#include <cassert>
 #include <iostream>
 #include <string>
@@ -14,6 +15,7 @@ namespace expert_specialization {
 using namespace cute;
 
+template <typename T>
 void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
     // Output
     torch::Tensor& out_ptrs,
@@ -33,15 +35,14 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& problem_sizes,
-    torch::Tensor const& expert_offsets) {
+    torch::Tensor const& expert_offsets,
+    bool is_h20_device,
+    cudaStream_t stream) {
   TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
-  const std::string H20_device_type_str("NVIDIA H20");
-  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
-
   // Creat Scale Factor Layout Functor
   using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
   using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
@@ -49,74 +50,38 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
       reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
 
   int num_experts = (int)expert_offsets.size(0);
-  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
-
-  // Dispatch
-  if (out_tensors.dtype() == torch::kBFloat16) {
-    struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::bfloat16_t> of(
-        static_cast<int*>(expert_offsets.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
-        static_cast<cutlass::bfloat16_t*>(out_tensors.data_ptr()),
-        static_cast<float*>(a_scales.data_ptr()),
-        static_cast<float*>(b_scales.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
-        static_cast<float**>(a_scales_ptrs.data_ptr()),
-        static_cast<float**>(b_scales_ptrs.data_ptr()),
-        static_cast<cutlass::bfloat16_t**>(out_ptrs.data_ptr()));
-    if (!is_h20_device) {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    } else {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    }
-  } else if (out_tensors.dtype() == torch::kFloat16) {
-    struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::half_t> of(
-        static_cast<int*>(expert_offsets.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
-        static_cast<cutlass::half_t*>(out_tensors.data_ptr()),
-        static_cast<float*>(a_scales.data_ptr()),
-        static_cast<float*>(b_scales.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
-        static_cast<float**>(a_scales_ptrs.data_ptr()),
-        static_cast<float**>(b_scales_ptrs.data_ptr()),
-        static_cast<cutlass::half_t**>(out_ptrs.data_ptr()));
-    if (!is_h20_device) {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    } else {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    }
-  } else {
-    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+  TORCH_CHECK(num_experts <= 1024, "Expert more than 1024");  // Max threads per block is 1024
+
+  struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, T> of(
+      static_cast<int*>(expert_offsets.data_ptr()),
+      static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
+      static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
+      static_cast<T*>(out_tensors.data_ptr()),
+      static_cast<float*>(a_scales.data_ptr()),
+      static_cast<float*>(b_scales.data_ptr()),
+      static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
+      static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
+      static_cast<float**>(a_scales_ptrs.data_ptr()),
+      static_cast<float**>(b_scales_ptrs.data_ptr()),
+      static_cast<T**>(out_ptrs.data_ptr()));
+  if (!is_h20_device) {
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
+        static_cast<int*>(lm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
+        static_cast<int*>(mm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
+        static_cast<int*>(hm_problem_sizes.data_ptr()));
+    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
+        static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
+  } else {
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
+        static_cast<int*>(lm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
+        static_cast<int*>(mm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
+        static_cast<int*>(hm_problem_sizes.data_ptr()));
+    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
+        static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
   }
 }
@@ -132,7 +97,8 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
     const torch::Tensor& stride_d,
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
-    const torch::Tensor& problem_sizes) {
+    const torch::Tensor& problem_sizes,
+    cudaStream_t stream) {
   using ElementA = typename GemmTraits::ElementA;
   using StrideA = typename GemmTraits::StrideA;
   using ElementB = typename GemmTraits::ElementB;
@@ -174,9 +140,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
       epilogue_args,
       hw_info};
 
-  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
-
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
@@ -205,7 +168,9 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
     const torch::Tensor& layout_sfb,
     const torch::Tensor& lm_problem_sizes,
     const torch::Tensor& mm_problem_sizes,
-    const torch::Tensor& hm_problem_sizes) {
+    const torch::Tensor& hm_problem_sizes,
+    bool is_h20_device,
+    cudaStream_t stream) {
   using LowMGemmH20Traits =
       ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMH20>;
   using LowMGemmHx00Traits =
@@ -221,9 +186,6 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
   using HighMGemmHx00Traits =
       ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;
 
-  const std::string H20_device_type_str("NVIDIA H20");
-  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
-
   if (!is_h20_device) {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmHx00Traits>(
         out_ptrs,
@@ -236,7 +198,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        lm_problem_sizes);
+        lm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
         out_ptrs,
@@ -249,7 +212,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        lm_problem_sizes);
+        lm_problem_sizes,
+        stream);
   }
 
   if (!is_h20_device) {
@@ -264,7 +228,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        mm_problem_sizes);
+        mm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
         out_ptrs,
@@ -277,7 +242,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        mm_problem_sizes);
+        mm_problem_sizes,
+        stream);
   }
 
   if (!is_h20_device) {
@@ -292,7 +258,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
         out_ptrs,
@@ -305,7 +272,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        stream);
   }
 }
...
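
For context on the `groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>` launch and the new `num_experts <= 1024` check above: the pre-compute runs as a single block with one thread per expert, and 1024 is CUDA's per-block thread limit. A hypothetical stand-in (not the sgl-kernel kernel or functor types) showing that launch shape:

```cuda
#include <cuda_runtime.h>

// Hypothetical functor: each thread fills the entry for its own expert.
struct OffsetFunctorSketch {
  const int* problem_sizes;  // flattened (M, N, K) triples, one per expert
  int* m_out;                // one output slot per expert
  __device__ void operator()(int expert) const {
    m_out[expert] = problem_sizes[3 * expert];  // e.g. record this expert's M
  }
};

template <typename Functor>
__global__ void precompute_sketch(Functor f) {
  f(threadIdx.x);  // gridDim.x == 1, blockDim.x == num_experts (<= 1024)
}

// Launch shape mirroring the diff:
//   precompute_sketch<<<1, num_experts, 0, stream>>>(functor);
```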
@@ -244,7 +244,7 @@ from sgl_kernel.elementwise import (
     rmsnorm,
     silu_and_mul,
 )
-from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
+from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
 from sgl_kernel.fused_moe import fused_marlin_moe
 from sgl_kernel.gemm import (
     awq_dequantize,
...