"src/vscode:/vscode.git/clone" did not exist on "3cf4f9c7353b3d42c6b6b527df5c0359600bb197"
Unverified Commit 0fc54b97 authored by kousakawang's avatar kousakawang Committed by GitHub
Browse files

[fix]: fix cutlass moe ut and and Opt H20 cutlass groupGemm performance (#9272)


Co-authored-by: default avatarwanghanpei <wanghanpei@bytedance.com>
parent b3c1f2e4
......@@ -153,9 +153,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=False, # Use False for benchmarking to avoid side effects if run multiple times
(topk_weights, topk_ids, "dummy"),
inplace=False,
activation="silu", # Assuming SiLU activation common in MoEs
use_fp8_w8a8=True,
w1_scale=w1_scale,
......@@ -221,8 +220,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
x,
w1, # Original shape
w2, # Original shape
topk_weights,
topk_ids,
(topk_weights, topk_ids, "dummy"),
inplace=False, # Important: Use False to get output tensor
activation="silu",
use_fp8_w8a8=True,
......@@ -266,7 +264,7 @@ if __name__ == "__main__":
"--batch-sizes",
type=int,
nargs="+",
default=[1, 4, 8, 16, 32, 64, 128, 256, 512], # Adjusted default
default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default
help="List of batch sizes to test",
)
parser.add_argument("--check", action="store_true", help="Enable check mode")
......
......@@ -437,6 +437,34 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
}
}
#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) { \
using ElementA = cutlass::float_e4m3_t; \
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
};
#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) { \
using ElementA = cutlass::float_e4m3_t; \
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
};
template <typename OutType>
void sm90_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor& output,
......@@ -481,13 +509,24 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
// [NOTE] Tuned for H20
GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
int num_experts = (int)expert_offsets.size(0);
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
// For H20 with K > 128, use Pingpong Schedule
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
const std::string H20_device_type_str = "NVIDIA H20";
bool is_h20 = isDeviceType(H20_device_type_str);
if (is_h20 && tuning_H20_kernel) {
using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
run_get_group_gemm_starts<
execute_gemm_config::LayoutSFA,
execute_gemm_config::LayoutSFB,
execute_gemm_config::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
......@@ -503,7 +542,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
......@@ -518,37 +558,71 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
expert_offsets,
workspace);
} else {
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
// For H20 with K > 128, use Pingpong Schedule
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
}
}
......
......@@ -254,6 +254,25 @@ inline int getSMVersion() {
return sm_major * 10 + sm_minor;
}
inline bool isDeviceType(const std::string& device_type) {
int deviceCount;
CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount));
int device_id = -1;
if (deviceCount >= 1) {
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
} else {
return false;
}
cudaDeviceProp prop;
CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id));
if (device_type == std::string(prop.name)) {
return true;
}
return false;
}
inline bool getBoolEnv(char const* name) {
char const* env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment