Unverified commit 6c01844f, authored by Qi Yuhang, committed by GitHub

[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)

parent f226d3da
......@@ -133,6 +133,7 @@ def bench_es(
     d_strides = torch.full(
         (num_groups,), c_out.stride(0), device=device, dtype=torch.int64
     )
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
     def run_cutlass():
         es_fp8_blockwise_scaled_grouped_mm(
......@@ -146,6 +147,7 @@ def bench_es(
             d_strides,
             problem_sizes,
             expert_offsets[:-1],
+            workspace,
         )
     run_cutlass()
......
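Note on the benchmark change: the 1 GiB uint8 workspace is allocated once in bench_es, outside run_cutlass, so the timed closure reuses the same buffer and the measurement no longer includes a device allocation per call. A minimal, runnable sketch of that pattern; the stand-in kernel and the 1 MiB size are illustrative, not the real grouped GEMM:

import torch

device = "cuda"
# Allocate scratch once, outside the timed region, as bench_es now does.
workspace = torch.empty(1024 * 1024, device=device, dtype=torch.uint8)

def run_kernel():
    # Stand-in for es_fp8_blockwise_scaled_grouped_mm(..., workspace):
    # every invocation reuses the preallocated buffer.
    workspace.zero_()

run_kernel()  # warm-up, mirroring the run_cutlass() call above
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    run_kernel()
end.record()
torch.cuda.synchronize()
print(f"avg: {start.elapsed_time(end) / 100:.3f} ms")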
......@@ -537,7 +537,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   */
   m.def(
       "es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
-      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets) -> ()");
+      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
+      "()");
   m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
 }
......
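The schema change only appends a trailing Tensor workspace argument; the op still returns nothing and writes into output in place. A quick way to confirm the new registration after rebuilding, assuming that importing sgl_kernel loads this TORCH_LIBRARY fragment (and noting that _schema is a private PyTorch attribute):

import torch
import sgl_kernel  # assumed to load the sgl_kernel TORCH_LIBRARY registrations

# Prints the registered signature, which should now end in
# "Tensor workspace) -> ()".
print(torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default._schema)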
......@@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& stride_b,
     const torch::Tensor& stride_d,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
   TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
   TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
......@@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
         lm_problem_sizes,
         mm_problem_sizes,
         hm_problem_sizes,
+        workspace,
         is_h20_device,
         stream);
   } else if (output.dtype() == torch::kFloat16) {
......@@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
         lm_problem_sizes,
         mm_problem_sizes,
         hm_problem_sizes,
+        workspace,
         is_h20_device,
         stream);
   } else {
......
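For reference, the checks above pin down the metadata contract: problem_sizes must be a 2-D (num_experts, 3) tensor, and the output dtype selects the bfloat16 or float16 dispatch path (anything else falls through to the error branch). Below is a hedged sketch of metadata that satisfies them; the int32 dtypes and the per-row (m, n, k) ordering are assumptions, since this hunk only shows the shape checks:

import torch

device = "cuda"
ms, n, k = [64, 128, 32], 512, 256  # illustrative per-expert M sizes
problem_sizes = torch.tensor(
    [[m, n, k] for m in ms], device=device, dtype=torch.int32
)  # shape (num_experts, 3), as the TORCH_CHECKs require
offsets = [0]
for m in ms:
    offsets.append(offsets[-1] + m)  # cumulative row offsets per expert
expert_offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
# Output rows total expert_offsets[-1]; dtype must be bf16 or fp16.
c_out = torch.empty((offsets[-1], n), device=device, dtype=torch.bfloat16)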
......@@ -98,6 +98,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
+    const torch::Tensor& workspace,
     cudaStream_t stream) {
   using ElementA = typename GemmTraits::ElementA;
   using StrideA = typename GemmTraits::StrideA;
......@@ -143,10 +144,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
-  torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device());
-  size_t workspace_size = gemm_op.get_workspace_size(args);
-  torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
   auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
......@@ -169,6 +166,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
     const torch::Tensor& lm_problem_sizes,
     const torch::Tensor& mm_problem_sizes,
     const torch::Tensor& hm_problem_sizes,
+    const torch::Tensor& workspace,
     bool is_h20_device,
     cudaStream_t stream) {
   using LowMGemmH20Traits =
......@@ -199,6 +197,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         lm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
......@@ -213,6 +212,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         lm_problem_sizes,
+        workspace,
         stream);
   }
......@@ -229,6 +229,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         mm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
......@@ -243,6 +244,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         mm_problem_sizes,
+        workspace,
         stream);
   }
......@@ -259,6 +261,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         hm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
......@@ -273,6 +276,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         hm_problem_sizes,
+        workspace,
         stream);
   }
 }
......
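The substantive change is in the deleted lines above: launch_sm90_fp8_blockwise_scaled_group_mm used to size and allocate its own uint8 workspace via gemm_op.get_workspace_size on every call, and now receives a caller-owned buffer that is threaded unchanged into all six launch sites (low-, mid-, and high-M buckets, H20 and non-H20 traits alike). From the Python side, that means one allocation can serve every call. A hedged sketch of the caller's new responsibility; both the benchmark and the test in this commit size the buffer at 1 GiB, a conservative upper bound rather than a queried minimum:

import torch

# Allocate once (e.g., at model load) and reuse for every grouped GEMM call;
# the per-launch torch::empty that the removed lines performed is now the
# caller's job. The 1 GiB figure follows the test/benchmark in this commit.
workspace = torch.empty(1024 * 1024 * 1024, device="cuda", dtype=torch.uint8)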
......@@ -835,4 +835,5 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& stride_b,
     const torch::Tensor& stride_d,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets);
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace);
......@@ -12,6 +12,7 @@ def es_fp8_blockwise_scaled_grouped_mm(
     stride_d,
     problem_sizes,
     expert_offsets,
+    workspace,
 ):
     torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
         output,
......@@ -24,4 +25,5 @@ def es_fp8_blockwise_scaled_grouped_mm(
         stride_d,
         problem_sizes,
         expert_offsets,
+        workspace,
     )
......@@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         ].t()  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
     b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
     b_scale_stack = b_scale_stack.transpose(1, 2)
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
     c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
     a_strides = torch.full(
         (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
......@@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         d_strides,
         problem_sizes,
         expert_offsets[:-1],
+        workspace,
     )
     for g in range(num_experts):
......
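The stride metadata the test builds is uniform per expert: each strides tensor repeats the corresponding operand's leading stride num_experts times as int64, matching a_strides here and d_strides in the benchmark. A small runnable sketch of that pattern; the 2-D operand layout below is an assumption for illustration, since the test's tensor construction is elided in this hunk:

import torch

device = "cuda"
num_experts, total_m, n, k = 4, 256, 512, 128
a_stack = torch.empty((total_m, k), device=device)  # operand layout assumed
c_out = torch.empty((total_m, n), device=device, dtype=torch.bfloat16)
a_strides = torch.full(
    (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
)
d_strides = torch.full(
    (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
)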