Unverified Commit 5fd311d3 authored by kousakawang, committed by GitHub

[code clean] add H20 cutlass groupGemm default config (#9333)


Co-authored-by: wanghanpei <wanghanpei@bytedance.com>
parent 53e2cd46
@@ -437,34 +437,6 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
   }
 }
-#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
-#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
-#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C)                                                  \
-  struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) {                                                       \
-    using ElementA = cutlass::float_e4m3_t;                                                            \
-    using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;                              \
-    using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;                              \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;                    \
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;                        \
-    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());                                       \
-    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());                                       \
-  };
-#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C)                                                     \
-  struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) {                                                          \
-    using ElementA = cutlass::float_e4m3_t;                                                               \
-    using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;                                 \
-    using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;                                 \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;                    \
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;                           \
-    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());                                          \
-    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());                                          \
-  };
 template <typename OutType>
 void sm90_fp8_blockwise_group_mm_dispatch_shape(
     torch::Tensor& output,
@@ -509,20 +481,28 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
-  // [NOTE] Tuned for H20
-  GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
+  // [NOTE] default for H20
+  struct MmaConfigH20_default {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_64, _128, _128>;
+    using ClusterShape = Shape<_1, _2, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
-  bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
   const std::string H20_device_type_str = "NVIDIA H20";
-  bool is_h20 = isDeviceType(H20_device_type_str);
-  if (is_h20 && tuning_H20_kernel) {
-    using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
+  bool is_h20_device = isDeviceType(H20_device_type_str);
+  if (is_h20_device) {
+    using execute_gemm_config = MmaConfigH20_default;
     run_get_group_gemm_starts<
         execute_gemm_config::LayoutSFA,
         execute_gemm_config::LayoutSFB,
...
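
For reference, the removed macros build the config struct name by token pasting, so GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1) expanded to exactly the sm90_fp8_pp_config_64_128_128_1_2_1 type that the old dispatch path referenced; the new MmaConfigH20_default spells out the same pingpong schedule, 64x128x128 tile, and 1x2x1 cluster as a plain struct. Below is a minimal, self-contained sketch of that naming scheme with placeholder integer members standing in for the CUTLASS type aliases so it compiles on its own; it is an illustration, not code from the repository.

// Sketch of the token-pasting scheme used by the removed macros.
// Placeholder constexpr members stand in for the real CUTLASS type aliases.
#include <cstdio>

#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c

#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C)                \
  struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) {                     \
    static constexpr int tile_m = M, tile_n = N, tile_k = K;         \
    static constexpr int cluster_m = A, cluster_n = B, cluster_k = C; \
  };

// Expands to: struct sm90_fp8_pp_config_64_128_128_1_2_1 { ... };
GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)

int main() {
  using cfg = sm90_fp8_pp_config_64_128_128_1_2_1;
  std::printf("tile %dx%dx%d, cluster %dx%dx%d\n",
              cfg::tile_m, cfg::tile_n, cfg::tile_k,
              cfg::cluster_m, cfg::cluster_n, cfg::cluster_k);
  return 0;
}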
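
The new dispatch path keys only off isDeviceType("NVIDIA H20") and no longer requires the SGL_TUNE_DEVICE_KERNEL environment flag. The helper's implementation is not shown in this diff; one plausible sketch (an assumption, not the repository's code) would query the current device name through the CUDA runtime:

// Hypothetical device-name check; the real isDeviceType() helper lives
// elsewhere in the repository and may be implemented differently.
#include <cuda_runtime.h>
#include <string>

static bool isDeviceTypeSketch(const std::string& expected) {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) return false;
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
  // prop.name holds a human-readable name such as "NVIDIA H20".
  return std::string(prop.name).find(expected) != std::string::npos;
}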