Unverified Commit d9def43d authored by Qi Yuhang, committed by GitHub

[Perf] Use Cooperative Schedule for H100 & H200 & H800 in fp8_blockwise_scaled_grouped_mm (#8722)

parent e273aa6d
@@ -485,7 +485,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
-  if (a.size(1) > 128) {
+  if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
+    // For H20 with K > 128, use Pingpong Schedule
     run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
         expert_offsets,
         a_ptrs,
@@ -517,7 +518,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
         expert_offsets,
         workspace);
   } else {
-    // Small K
+    // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
     run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
         expert_offsets,
         a_ptrs,
...
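
For context, the change keys the schedule choice off the device's SM count: H20 is identified here by its 78 SMs, while H100/H200/H800 report higher counts and now always take the Cooperative schedule. The sketch below is a minimal, hypothetical distillation of that dispatch condition, not code from the repository; the helper name select_sm90_schedule and the Sm90Schedule enum are invented for illustration, and K is taken to be a.size(1) as in the diff.

// Minimal sketch of the schedule selection introduced by this commit.
// Assumption: H20 is distinguished solely by its SM count (78), as the diff does;
// select_sm90_schedule and Sm90Schedule are illustrative names, not repository code.
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

enum class Sm90Schedule { Pingpong, Cooperative };

inline Sm90Schedule select_sm90_schedule(const torch::Tensor& a) {
  // multiProcessorCount == 78 separates H20 from H100/H200/H800 on SM90.
  const bool is_h20 = at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78;
  const bool large_k = a.size(1) > 128;  // K dimension of A, as checked in the diff
  // H20 with K > 128 keeps the Pingpong schedule (the MmaConfig0 path);
  // everything else now takes the Cooperative schedule (the MmaConfig1 path).
  return (is_h20 && large_k) ? Sm90Schedule::Pingpong : Sm90Schedule::Cooperative;
}

Keying on the SM count avoids string-matching device names, though any future SM90 part that also ships with 78 SMs would be classified as an H20 by this check.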