"docker/vscode:/vscode.git/clone" did not exist on "1a5797c6d4491a879ea5285c4efc377664e0332d"
Unverified commit dc48c4c0, authored by Qi Yuhang and committed by GitHub

[sgl-kernel][2/N]Support Expert Specialization Grouped GEMM (#11534)
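In the diff below, the host entry point `es_fp8_blockwise_scaled_grouped_mm` now resolves the NVIDIA H20 device check and the CUDA stream once and passes them down, and the pre-compute step becomes a single `template <typename T>` helper instead of duplicated bfloat16/float16 branches. A minimal sketch of that dispatch pattern (illustrative only; `pre_compute` and `entry` are placeholder names, not the sgl-kernel API):

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cutlass/bfloat16.h>
#include <cutlass/half.h>

#include <string>

namespace es_example {

// Placeholder for the templated pre-compute step; in the diff this role is
// played by es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<T>.
template <typename OutT>
void pre_compute(const torch::Tensor& output, bool is_h20_device, cudaStream_t stream) {
  // ... fill per-expert pointer arrays and problem sizes on `stream`,
  // picking H20-specific or generic Hopper perf configs via `is_h20_device`.
  (void)output; (void)is_h20_device; (void)stream;
}

void entry(const torch::Tensor& a, const torch::Tensor& output) {
  // Resolve device-specific tuning once: the NVIDIA H20 SKU gets its own configs.
  const bool is_h20_device =
      std::string(at::cuda::getCurrentDeviceProperties()->name) == "NVIDIA H20";

  // Bind the device and stream here and thread the stream through the helpers,
  // instead of re-querying them inside every callee.
  const at::cuda::CUDAGuard device_guard{(char)a.get_device()};
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());

  // One template replaces the duplicated bf16/fp16 pre-compute branches.
  if (output.dtype() == torch::kBFloat16) {
    pre_compute<cutlass::bfloat16_t>(output, is_h20_device, stream);
  } else if (output.dtype() == torch::kFloat16) {
    pre_compute<cutlass::half_t>(output, is_h20_device, stream);
  } else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace es_example
```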

parent 6dc9ca8c
......@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
layout_sfa,
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
output,
a,
b,
scales_a,
scales_b,
problem_sizes,
expert_offsets);
const std::string H20_device_type_str("NVIDIA H20");
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
at::cuda::CUDAGuard device_guard{(char)a.get_device()};
cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
if (output.dtype() == torch::kBFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
layout_sfa,
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
output,
a,
b,
scales_a,
scales_b,
problem_sizes,
expert_offsets,
is_h20_device,
stream);
} else if (output.dtype() == torch::kFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
layout_sfa,
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
output,
a,
b,
scales_a,
scales_b,
problem_sizes,
expert_offsets,
is_h20_device,
stream);
} else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
if (output.dtype() == torch::kBFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
out_ptrs,
......@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes);
hm_problem_sizes,
is_h20_device,
stream);
} else if (output.dtype() == torch::kFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
out_ptrs,
......@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes);
hm_problem_sizes,
is_h20_device,
stream);
} else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
......
......@@ -3,6 +3,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cassert>
#include <iostream>
#include <string>
......@@ -14,6 +15,7 @@ namespace expert_specialization {
using namespace cute;
template <typename T>
void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
// Output
torch::Tensor& out_ptrs,
......@@ -33,15 +35,14 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& problem_sizes,
torch::Tensor const& expert_offsets) {
torch::Tensor const& expert_offsets,
bool is_h20_device,
cudaStream_t stream) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
const std::string H20_device_type_str("NVIDIA H20");
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
// Create Scale Factor Layout Functor
using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
......@@ -49,74 +50,38 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
// Dispatch
if (out_tensors.dtype() == torch::kBFloat16) {
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::bfloat16_t> of(
static_cast<int*>(expert_offsets.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
static_cast<cutlass::bfloat16_t*>(out_tensors.data_ptr()),
static_cast<float*>(a_scales.data_ptr()),
static_cast<float*>(b_scales.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
static_cast<float**>(a_scales_ptrs.data_ptr()),
static_cast<float**>(b_scales_ptrs.data_ptr()),
static_cast<cutlass::bfloat16_t**>(out_ptrs.data_ptr()));
if (!is_h20_device) {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
} else if (out_tensors.dtype() == torch::kFloat16) {
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::half_t> of(
static_cast<int*>(expert_offsets.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
static_cast<cutlass::half_t*>(out_tensors.data_ptr()),
static_cast<float*>(a_scales.data_ptr()),
static_cast<float*>(b_scales.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
static_cast<float**>(a_scales_ptrs.data_ptr()),
static_cast<float**>(b_scales_ptrs.data_ptr()),
static_cast<cutlass::half_t**>(out_ptrs.data_ptr()));
if (!is_h20_device) {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
TORCH_CHECK(num_experts <= 1024, "num_experts must not exceed 1024");  // Max threads per block is 1024
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, T> of(
static_cast<int*>(expert_offsets.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
static_cast<T*>(out_tensors.data_ptr()),
static_cast<float*>(a_scales.data_ptr()),
static_cast<float*>(b_scales.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
static_cast<float**>(a_scales_ptrs.data_ptr()),
static_cast<float**>(b_scales_ptrs.data_ptr()),
static_cast<T**>(out_ptrs.data_ptr()));
if (!is_h20_device) {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
}
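The pre-compute above launches `groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>`, i.e. a single block with one thread per expert, which is why the `num_experts <= 1024` guard was added. The real kernel and its functors are defined elsewhere in sgl-kernel; the standalone sketch below only illustrates that launch shape and the kind of per-expert pointer arithmetic involved, with invented layouts, dtypes, and names:

```cuda
#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>

__global__ void precompute_per_expert(
    const int* expert_offsets,  // row offset of each expert's first token
    const int* problem_sizes,   // [num_experts, 3] packed as (m, n, k)
    const float* a,             // activations, row-major [total_m, k]
    const float* b,             // weights, [num_experts, n, k]
    float* out,                 // outputs, row-major [total_m, n]
    const float** a_ptrs,
    const float** b_ptrs,
    float** out_ptrs) {
  const int e = threadIdx.x;    // one thread per expert
  const int n = problem_sizes[3 * e + 1];
  const int k = problem_sizes[3 * e + 2];
  const long row0 = expert_offsets[e];
  a_ptrs[e] = a + row0 * k;            // expert e's slice of A
  b_ptrs[e] = b + (long)e * n * k;     // expert e's weight matrix
  out_ptrs[e] = out + row0 * n;        // expert e's slice of the output
}

int main() {
  const int num_experts = 2, k = 4, n = 4, total_m = 6;
  const int h_offsets[2] = {0, 2};            // expert 0 owns rows [0, 2), expert 1 owns [2, 6)
  const int h_sizes[6] = {2, n, k, 4, n, k};  // (m, n, k) per expert
  int *offsets, *sizes;
  float *a, *b, *out, **out_ptrs;
  const float **a_ptrs, **b_ptrs;
  cudaMallocManaged(&offsets, sizeof(h_offsets));
  cudaMallocManaged(&sizes, sizeof(h_sizes));
  cudaMallocManaged(&a, total_m * k * sizeof(float));
  cudaMallocManaged(&b, num_experts * n * k * sizeof(float));
  cudaMallocManaged(&out, total_m * n * sizeof(float));
  cudaMallocManaged(&a_ptrs, num_experts * sizeof(float*));
  cudaMallocManaged(&b_ptrs, num_experts * sizeof(float*));
  cudaMallocManaged(&out_ptrs, num_experts * sizeof(float*));
  memcpy(offsets, h_offsets, sizeof(h_offsets));
  memcpy(sizes, h_sizes, sizeof(h_sizes));
  // Same launch shape as the diff: one block, one thread per expert.
  precompute_per_expert<<<1, num_experts>>>(offsets, sizes, a, b, out, a_ptrs, b_ptrs, out_ptrs);
  cudaDeviceSynchronize();
  printf("expert 1 starts %ld rows into A\n", (long)(a_ptrs[1] - a) / k);
  return 0;
}
```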
......@@ -132,7 +97,8 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
const torch::Tensor& stride_d,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes) {
const torch::Tensor& problem_sizes,
cudaStream_t stream) {
using ElementA = typename GemmTraits::ElementA;
using StrideA = typename GemmTraits::StrideA;
using ElementB = typename GemmTraits::ElementB;
......@@ -174,9 +140,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
......@@ -205,7 +168,9 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
const torch::Tensor& layout_sfb,
const torch::Tensor& lm_problem_sizes,
const torch::Tensor& mm_problem_sizes,
const torch::Tensor& hm_problem_sizes) {
const torch::Tensor& hm_problem_sizes,
bool is_h20_device,
cudaStream_t stream) {
using LowMGemmH20Traits =
ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMH20>;
using LowMGemmHx00Traits =
......@@ -221,9 +186,6 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
using HighMGemmHx00Traits =
ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;
const std::string H20_device_type_str("NVIDIA H20");
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
if (!is_h20_device) {
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmHx00Traits>(
out_ptrs,
......@@ -236,7 +198,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
stride_d,
layout_sfb,
layout_sfa,
lm_problem_sizes);
lm_problem_sizes,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
out_ptrs,
......@@ -249,7 +212,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
stride_d,
layout_sfb,
layout_sfa,
lm_problem_sizes);
lm_problem_sizes,
stream);
}
if (!is_h20_device) {
......@@ -264,7 +228,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
stride_d,
layout_sfb,
layout_sfa,
mm_problem_sizes);
mm_problem_sizes,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
out_ptrs,
......@@ -277,7 +242,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
stride_d,
layout_sfa,
layout_sfb,
mm_problem_sizes);
mm_problem_sizes,
stream);
}
if (!is_h20_device) {
......@@ -292,7 +258,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
stride_d,
layout_sfa,
layout_sfb,
hm_problem_sizes);
hm_problem_sizes,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
out_ptrs,
......@@ -305,7 +272,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
stride_d,
layout_sfa,
layout_sfb,
hm_problem_sizes);
hm_problem_sizes,
stream);
}
}
......
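The pre-compute also partitions the per-expert problem sizes into three lists (`lm_`, `mm_`, and `hm_problem_sizes`), and `es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype` then launches the grouped GEMM once per list, selecting H20-specific or generic Hopper (`Hx00`) traits via `is_h20_device`. The `PerfConfigLowM`/`MiddleM`/`HighM` names suggest the buckets are keyed on the M dimension. The real `Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor` is not part of this diff, so the sketch below, with invented thresholds and names, only illustrates one plausible way such a filter routes each expert into exactly one bucket:

```cuda
#include <cstdio>

// Invented M-range thresholds; the actual PerfConfig tile shapes and any
// associated cutoffs live elsewhere in sgl-kernel.
struct LowM    { static constexpr int kMinM = 0;   static constexpr int kMaxM = 64;      };
struct MiddleM { static constexpr int kMinM = 64;  static constexpr int kMaxM = 512;     };
struct HighM   { static constexpr int kMinM = 512; static constexpr int kMaxM = 1 << 30; };

template <typename Bucket>
struct ProblemSizeFilter {
  int* sizes;  // destination [num_experts, 3] list for this bucket
  explicit ProblemSizeFilter(int* dst) : sizes(dst) {}

  // Keep (m, n, k) if m falls in this bucket's range; otherwise write an empty
  // problem so the corresponding grouped-GEMM launch skips the expert.
  __host__ __device__ void operator()(int expert, int m, int n, int k) const {
    const bool keep = m >= Bucket::kMinM && m < Bucket::kMaxM;
    sizes[3 * expert + 0] = keep ? m : 0;
    sizes[3 * expert + 1] = keep ? n : 0;
    sizes[3 * expert + 2] = keep ? k : 0;
  }
};

int main() {
  int lm[6] = {0}, mm[6] = {0}, hm[6] = {0};
  ProblemSizeFilter<LowM> low(lm);
  ProblemSizeFilter<MiddleM> mid(mm);
  ProblemSizeFilter<HighM> high(hm);
  const int m_per_expert[2] = {16, 2048};  // expert 0 -> low-M list, expert 1 -> high-M list
  for (int e = 0; e < 2; ++e) {
    low(e, m_per_expert[e], 7168, 2048);
    mid(e, m_per_expert[e], 7168, 2048);
    high(e, m_per_expert[e], 7168, 2048);
  }
  printf("low-M m's: %d %d | high-M m's: %d %d\n", lm[0], lm[3], hm[0], hm[3]);
  return 0;
}
```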
......@@ -244,7 +244,7 @@ from sgl_kernel.elementwise import (
rmsnorm,
silu_and_mul,
)
from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
awq_dequantize,
......