"docs/vscode:/vscode.git/clone" did not exist on "e2ca0733c63df70d0b756b8ced351c2fe69414c9"
Unverified commit 8e9fb43d, authored by Qi Yuhang and committed by GitHub

Optimize Hopper CUTLASS FP8 Blockwise Grouped GEMM Kernel in Small K Scenario (#7782)

parent 83646089
@@ -61,7 +61,12 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
   using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
-  using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator>;
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using CustomEVTIdentity =  // acc
+      cutlass::epilogue::fusion::Sm90EVT<
+          cutlass::epilogue::fusion::
+              Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
+          cutlass::epilogue::fusion::Sm90AccFetch>;
 
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -78,7 +83,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
       LayoutC*,
       AlignmentC,
       typename ScheduleConfig::EpilogueSchedule,
-      FusionOperation>::CollectiveOp;
+      CustomEVTIdentity>::CollectiveOp;
 
   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
       ArchTag,
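The two hunks above change the epilogue fusion used by the CollectiveBuilder: the default LinearCombination, which computes D = alpha * acc + beta * C, is replaced by a minimal EVT (Epilogue Visitor Tree) whose single compute node fetches the accumulator (Sm90AccFetch) and converts it to the output type through an Identity op. Per output element the semantics reduce to the following sketch (plain host C++ for illustration only, assuming float for both the accumulator and ElementD; the real kernel converts to the requested output type on the device):

```cpp
using ElementAccumulator = float;  // accumulator type of the FP8 blockwise GEMM
using ElementD = float;            // assumption: the kernel's real ElementD is the requested OutType

// Old fusion, cutlass::epilogue::fusion::LinearCombination: D = alpha * acc + beta * C
ElementD linear_combination(ElementAccumulator acc, ElementAccumulator c, float alpha, float beta) {
  return static_cast<ElementD>(alpha * acc + beta * c);
}

// New fusion, Sm90EVT<Sm90Compute<Identity, ...>, Sm90AccFetch>: D = acc,
// i.e. the accumulator is written out directly, only rounded/converted to ElementD.
ElementD identity_evt(ElementAccumulator acc) {
  return static_cast<ElementD>(acc);
}
```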
@@ -452,7 +457,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace) {
-  struct MmaConfig {
+  struct MmaConfig0 {
     using ElementA = cutlass::float_e4m3_t;
     using MmaTileShape = Shape<_64, _128, _128>;
     using ClusterShape = Shape<_2, _1, _1>;
@@ -464,40 +469,86 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
 
+  struct MmaConfig1 {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _128, _128>;
+    using ClusterShape = Shape<_1, _2, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
+
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
 
-  run_get_group_gemm_starts<MmaConfig::LayoutSFA, MmaConfig::LayoutSFB, MmaConfig::ScaleConfig>(
-      expert_offsets,
-      a_ptrs,
-      b_ptrs,
-      out_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      a,
-      b,
-      output,
-      scales_a,
-      scales_b,
-      layout_sfa,
-      layout_sfb,
-      problem_sizes,
-      problem_sizes_transpose);
-  launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig, cutlass::layout::RowMajor>(
-      out_ptrs,
-      a_ptrs,
-      b_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      stride_a,
-      stride_b,
-      stride_c,
-      layout_sfa,
-      layout_sfb,
-      problem_sizes,
-      expert_offsets,
-      workspace);
+  if (a.size(1) > 128) {
+    run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  } else {
+    // Small K
+    run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  }
 }
 
 /**
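The hunk above adds the small-K path: MmaConfig0 keeps the original 64x128x128 tile with a 2x1x1 cluster, while the new MmaConfig1 uses a 128x128x128 tile, a 1x2x1 cluster, and the cooperative FP8 block-scaled schedule. Which configuration runs is decided by the K dimension of A (a.size(1)). A minimal host-side sketch of that selection, using an illustrative enum that does not exist in the real code:

```cpp
#include <cstdint>

// Illustrative only; the real dispatch instantiates the MmaConfig0 / MmaConfig1
// structs directly inside sm90_fp8_blockwise_group_mm_dispatch_shape.
enum class Sm90GroupedGemmConfig {
  kLargeK,  // MmaConfig0: MmaTileShape 64x128x128, ClusterShape 2x1x1
  kSmallK   // MmaConfig1: MmaTileShape 128x128x128, ClusterShape 1x2x1, cooperative schedule
};

// Threshold taken from the diff: K = a.size(1).
Sm90GroupedGemmConfig select_config(int64_t k) {
  return k > 128 ? Sm90GroupedGemmConfig::kLargeK : Sm90GroupedGemmConfig::kSmallK;
}
```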
@@ -641,7 +692,7 @@ void fp8_blockwise_scaled_grouped_mm(
 #endif
 
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
-  if (sm_version == 90 && a.size(1) > 256) {
+  if (sm_version == 90) {
     if (output.scalar_type() == torch::kBFloat16) {
       sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
           output,
@@ -687,8 +738,5 @@ void fp8_blockwise_scaled_grouped_mm(
   }
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
-      can_implement,
-      "No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ",
-      sm_version,
-      a.size(1));
+      can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
 }
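The last two hunks relax the SM90 entry condition to match: previously only inputs with K > 256 reached the SM90 grouped kernel, and everything else fell through to TORCH_CHECK_NOT_IMPLEMENTED, whose message therefore mentioned the K size. With the small-K configuration available, the guard now checks only the compute capability and the message drops the K size. A before/after sketch of the guard (function names are illustrative, not from the source):

```cpp
#include <cstdint>

// Before: the SM90 grouped kernel was only used for K > 256.
bool takes_sm90_path_before(int sm_version, int64_t k) {
  return sm_version == 90 && k > 256;
}

// After: every K is handled on SM90 (MmaConfig1 covers K <= 128,
// MmaConfig0 covers the rest), so only the SM version matters.
bool takes_sm90_path_after(int sm_version) {
  return sm_version == 90;
}
```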