Unverified commit f347ac6c, authored by Michael Goin and committed by GitHub
Browse files

[Perf] Fuse stride preparation for NVFP4 cutlass_moe (#31837)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 05f47bd8
...@@ -62,7 +62,9 @@ __global__ void __get_group_gemm_starts( ...@@ -62,7 +62,9 @@ __global__ void __get_group_gemm_starts(
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int, ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets, ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes, const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
const int K, const int N) { int64_t* a_strides, int64_t* b_strides, int64_t* c_strides,
const int64_t a_stride_val, const int64_t b_stride_val,
const int64_t c_stride_val, const int K, const int N) {
int64_t expert_id = threadIdx.x; int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) { if (expert_id >= gridDim.x * blockDim.x) {
return; return;
...@@ -103,6 +105,11 @@ __global__ void __get_group_gemm_starts( ...@@ -103,6 +105,11 @@ __global__ void __get_group_gemm_starts(
// Shape of alpha = [E] // Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id; alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
// Initialize strides (constant across all experts, avoids separate kernels)
a_strides[expert_id] = a_stride_val;
b_strides[expert_id] = b_stride_val;
c_strides[expert_id] = c_stride_val;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
...@@ -135,7 +142,11 @@ __global__ void __get_group_gemm_starts( ...@@ -135,7 +142,11 @@ __global__ void __get_group_gemm_starts(
static_cast<float*>(alphas.data_ptr()), \ static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \ static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \ static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \ static_cast<int32_t*>(problem_sizes.data_ptr()), \
static_cast<int64_t*>(a_strides.data_ptr()), \
static_cast<int64_t*>(b_strides.data_ptr()), \
static_cast<int64_t*>(c_strides.data_ptr()), a_stride_val, \
b_stride_val, c_stride_val, K, N); \
} }
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig> template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
...@@ -144,6 +155,9 @@ void run_get_group_gemm_starts( ...@@ -144,6 +155,9 @@ void run_get_group_gemm_starts(
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/ /*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales, torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
...@@ -269,17 +283,16 @@ void run_fp4_blockwise_scaled_group_mm_sm100( ...@@ -269,17 +283,16 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 = torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::full({num_experts}, output.stride(0), options_int); torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor a_strides1 = torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>( run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, layout_sfa, layout_sfb, a_strides1, b_strides1, c_strides1,
expert_offsets, sf_offsets, problem_sizes, M, N, K); a.stride(0) * 2, b.stride(1) * 2, output.stride(0), a, b, output,
a_blockscale, b_blockscales, alphas, expert_offsets, sf_offsets,
problem_sizes, M, N, K);
// Create an instance of the GEMM // Create an instance of the GEMM
Gemm gemm_op; Gemm gemm_op;
...@@ -444,17 +457,16 @@ void run_fp4_blockwise_scaled_group_mm_sm120( ...@@ -444,17 +457,16 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 = torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::full({num_experts}, output.stride(0), options_int); torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor a_strides1 = torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>( run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, layout_sfa, layout_sfb, a_strides1, b_strides1, c_strides1,
expert_offsets, sf_offsets, problem_sizes, M, N, K); a.stride(0) * 2, b.stride(1) * 2, output.stride(0), a, b, output,
a_blockscale, b_blockscales, alphas, expert_offsets, sf_offsets,
problem_sizes, M, N, K);
// Create an instance of the GEMM // Create an instance of the GEMM
Gemm gemm_op; Gemm gemm_op;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment