Unverified Commit 6c01844f authored by Qi Yuhang, committed by GitHub

[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)

parent f226d3da
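
In short, this change makes the CUTLASS scratch workspace a caller-provided argument: es_fp8_blockwise_scaled_grouped_mm now takes a trailing workspace tensor instead of allocating one inside the kernel launcher, and the benchmark and test below preallocate a 1 GiB uint8 buffer for it. Below is a minimal caller-side sketch of the new calling convention; the import path is assumed, and the operand variables are placeholders named after the registered schema arguments, which have to be prepared as in the test further down (fp8 A/B, blockwise scales, int64 strides, a (num_experts, 3) problem_sizes tensor, and expert offsets).

import torch

from sgl_kernel import es_fp8_blockwise_scaled_grouped_mm  # assumed import path

device = "cuda"
# New in this commit: the caller owns (and can reuse) the CUTLASS workspace.
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)

# Placeholder inputs named after the registered schema arguments; each must be
# prepared beforehand, e.g. as in the test file in this commit.
es_fp8_blockwise_scaled_grouped_mm(
    output,         # result buffer, shape (expert_offsets[-1], n), dtype bf16/fp16
    a,              # fp8 operand A
    b,              # fp8 operand B (transposed to column-major per expert in the test)
    scales_a,       # blockwise scales for A
    scales_b,       # blockwise scales for B
    stride_a,       # int64 per-group strides
    stride_b,
    stride_d,
    problem_sizes,  # shape (num_experts, 3)
    expert_offsets,
    workspace,      # new trailing argument: caller-provided uint8 scratch buffer
)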

@@ -133,6 +133,7 @@ def bench_es(
     d_strides = torch.full(
         (num_groups,), c_out.stride(0), device=device, dtype=torch.int64
     )
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)

     def run_cutlass():
         es_fp8_blockwise_scaled_grouped_mm(
@@ -146,6 +147,7 @@ def bench_es(
             d_strides,
             problem_sizes,
             expert_offsets[:-1],
+            workspace,
         )

     run_cutlass()
...

@@ -537,7 +537,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   */
   m.def(
       "es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
-      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets) -> ()");
+      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
+      "()");
   m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
 }
...

@@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& stride_b,
     const torch::Tensor& stride_d,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
   TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
   TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
@@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
         lm_problem_sizes,
         mm_problem_sizes,
         hm_problem_sizes,
+        workspace,
         is_h20_device,
         stream);
   } else if (output.dtype() == torch::kFloat16) {
@@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
         lm_problem_sizes,
         mm_problem_sizes,
         hm_problem_sizes,
+        workspace,
         is_h20_device,
         stream);
   } else {
...

@@ -98,6 +98,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
+    const torch::Tensor& workspace,
     cudaStream_t stream) {
   using ElementA = typename GemmTraits::ElementA;
   using StrideA = typename GemmTraits::StrideA;
@@ -143,10 +144,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

-  torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device());
-  size_t workspace_size = gemm_op.get_workspace_size(args);
-  torch::Tensor workspace = torch::empty(workspace_size, options_uint8);

   auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
@@ -169,6 +166,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
     const torch::Tensor& lm_problem_sizes,
     const torch::Tensor& mm_problem_sizes,
     const torch::Tensor& hm_problem_sizes,
+    const torch::Tensor& workspace,
     bool is_h20_device,
     cudaStream_t stream) {
   using LowMGemmH20Traits =
@@ -199,6 +197,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         lm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
@@ -213,6 +212,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         lm_problem_sizes,
+        workspace,
         stream);
   }
@@ -229,6 +229,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         mm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
@@ -243,6 +244,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         mm_problem_sizes,
+        workspace,
         stream);
   }
@@ -259,6 +261,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         hm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
@@ -273,6 +276,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         hm_problem_sizes,
+        workspace,
         stream);
   }
 }
...
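
Design note: launch_sm90_fp8_blockwise_scaled_group_mm no longer creates a temporary uint8 tensor per invocation; it hands workspace.data_ptr() straight to gemm_op.initialize. A practical consequence is that a caller can allocate the buffer once and reuse it across launches, which is what the benchmark does by creating workspace outside run_cutlass. A hedged sketch of that reuse pattern follows; the helper class, its name, and the 1 GiB sizing are illustrative, not part of the kernel API.

import torch

from sgl_kernel import es_fp8_blockwise_scaled_grouped_mm  # assumed import path


class EsGroupedGemm:
    """Illustrative wrapper that owns one reusable CUTLASS workspace buffer.

    The 1 GiB default mirrors the buffer allocated by the benchmark and test in
    this commit; the size actually required depends on the GEMM configuration.
    """

    def __init__(self, device="cuda", workspace_bytes=1024 * 1024 * 1024):
        self.workspace = torch.empty(
            workspace_bytes, device=device, dtype=torch.uint8
        )

    def __call__(
        self, output, a, b, scales_a, scales_b,
        stride_a, stride_b, stride_d, problem_sizes, expert_offsets,
    ):
        # Reuse the preallocated buffer instead of paying for a fresh
        # allocation on every grouped-GEMM launch.
        es_fp8_blockwise_scaled_grouped_mm(
            output, a, b, scales_a, scales_b,
            stride_a, stride_b, stride_d,
            problem_sizes, expert_offsets,
            self.workspace,
        )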

@@ -835,4 +835,5 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& stride_b,
     const torch::Tensor& stride_d,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets);
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace);

@@ -12,6 +12,7 @@ def es_fp8_blockwise_scaled_grouped_mm(
     stride_d,
     problem_sizes,
     expert_offsets,
+    workspace,
 ):
     torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
         output,
@@ -24,4 +25,5 @@ def es_fp8_blockwise_scaled_grouped_mm(
         stride_d,
         problem_sizes,
         expert_offsets,
+        workspace,
     )

@@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         ].t()  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
     b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
     b_scale_stack = b_scale_stack.transpose(1, 2)
-
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
     c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
     a_strides = torch.full(
         (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
@@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         d_strides,
         problem_sizes,
         expert_offsets[:-1],
+        workspace,
     )

     for g in range(num_experts):
...