Unverified commit 0fc54b97, authored by kousakawang, committed by GitHub

[fix]: fix cutlass moe ut and optimize H20 cutlass groupGemm performance (#9272)


Co-authored-by: wanghanpei <wanghanpei@bytedance.com>
parent b3c1f2e4
@@ -153,9 +153,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,
         w2,
-        topk_weights,
-        topk_ids,
-        inplace=False,  # Use False for benchmarking to avoid side effects if run multiple times
+        (topk_weights, topk_ids, "dummy"),
+        inplace=False,
         activation="silu",  # Assuming SiLU activation common in MoEs
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
@@ -221,8 +220,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,  # Original shape
         w2,  # Original shape
-        topk_weights,
-        topk_ids,
+        (topk_weights, topk_ids, "dummy"),
         inplace=False,  # Important: Use False to get output tensor
         activation="silu",
         use_fp8_w8a8=True,
@@ -266,7 +264,7 @@ if __name__ == "__main__":
         "--batch-sizes",
         type=int,
         nargs="+",
-        default=[1, 4, 8, 16, 32, 64, 128, 256, 512],  # Adjusted default
+        default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024],  # Adjusted default
         help="List of batch sizes to test",
     )
     parser.add_argument("--check", action="store_true", help="Enable check mode")
@@ -437,6 +437,34 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
   }
 }

+#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
+#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
+
+#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C) \
+  struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) { \
+    using ElementA = cutlass::float_e4m3_t; \
+    using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
+    using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
+  };
+
+#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C) \
+  struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) { \
+    using ElementA = cutlass::float_e4m3_t; \
+    using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
+    using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
+  };
+
 template <typename OutType>
 void sm90_fp8_blockwise_group_mm_dispatch_shape(
     torch::Tensor& output,
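The JOIN_STRUCT_*_NAME / GENERATE_SM90_FP8_*_CONFIG macros above stamp out one uniquely named CUTLASS config struct per tile-shape/cluster-shape/schedule combination by token-pasting the numeric parameters into the struct name, so GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1) yields a type called sm90_fp8_pp_config_64_128_128_1_2_1. A minimal, CUTLASS-free sketch of the same token-pasting pattern (the TileShape/demo names below are illustrative only, not from this commit):

#include <cstdio>

// Simplified stand-in for the cute::Shape<Int<M>, Int<N>, Int<K>> tile type used above.
template <int M, int N, int K>
struct TileShape {
  static constexpr int m = M, n = N, k = K;
};

// Token-paste the shape parameters into a unique struct name, as JOIN_STRUCT_PP_NAME does.
#define JOIN_DEMO_NAME(m, n, k) demo_pp_config##_##m##_##n##_##k
#define GENERATE_DEMO_CONFIG(M, N, K) \
  struct JOIN_DEMO_NAME(M, N, K) { \
    using MmaTileShape = TileShape<M, N, K>; \
  };

// One invocation per tuned shape; each expands to a distinct, reusable config struct.
GENERATE_DEMO_CONFIG(64, 128, 128)

int main() {
  using Cfg = demo_pp_config_64_128_128;
  std::printf("tile = %dx%dx%d\n", Cfg::MmaTileShape::m, Cfg::MmaTileShape::n, Cfg::MmaTileShape::k);
  return 0;
}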
@@ -481,10 +509,55 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };

+  // [NOTE] Tuned for H20
+  GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
+
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);

+  bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
+  const std::string H20_device_type_str = "NVIDIA H20";
+  bool is_h20 = isDeviceType(H20_device_type_str);
+
+  if (is_h20 && tuning_H20_kernel) {
+    using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
+    run_get_group_gemm_starts<
+        execute_gemm_config::LayoutSFA,
+        execute_gemm_config::LayoutSFB,
+        execute_gemm_config::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  } else {
   if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
     // For H20 with K > 128, use Pingpong Schedule
     run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
@@ -550,6 +623,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
         expert_offsets,
         workspace);
   }
+  }
 }

 /**
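Taken together, the new dispatch order in sm90_fp8_blockwise_group_mm_dispatch_shape is: use the H20-tuned 64x128x128 pingpong config when both the SGL_TUNE_DEVICE_KERNEL flag and the "NVIDIA H20" device check pass, otherwise fall through to the existing heuristic (the MmaConfig0 pingpong path on 78-SM devices with K > 128, the prior path otherwise). A dependency-free sketch of that decision; the enum and function names below are illustrative, not from the source:

#include <cstdio>

// Illustrative labels for the config families involved (not the real type names).
enum class GroupGemmConfig {
  H20TunedPingpong,   // sm90_fp8_pp_config_64_128_128_1_2_1, env-gated H20 path
  PingpongHeuristic,  // existing MmaConfig0 path for 78-SM devices with K > 128
  ExistingDefault     // remaining existing path (not shown in this diff)
};

// Mirrors the new if/else structure: env-gated H20 branch first, then the old heuristic.
GroupGemmConfig pick_config(bool is_h20, bool tune_flag, int sm_count, long k) {
  if (is_h20 && tune_flag) {
    return GroupGemmConfig::H20TunedPingpong;
  }
  if (sm_count == 78 && k > 128) {
    return GroupGemmConfig::PingpongHeuristic;
  }
  return GroupGemmConfig::ExistingDefault;
}

int main() {
  // Example: an H20 (78 SMs) with K = 7168 and SGL_TUNE_DEVICE_KERNEL=1 set.
  GroupGemmConfig cfg = pick_config(true, true, 78, 7168);
  std::printf("selected config: %d\n", static_cast<int>(cfg));
  return 0;
}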
@@ -254,6 +254,25 @@ inline int getSMVersion() {
   return sm_major * 10 + sm_minor;
 }

+inline bool isDeviceType(const std::string& device_type) {
+  int deviceCount;
+  CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount));
+  int device_id = -1;
+  if (deviceCount >= 1) {
+    CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
+  } else {
+    return false;
+  }
+  cudaDeviceProp prop;
+  CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id));
+  if (device_type == std::string(prop.name)) {
+    return true;
+  }
+  return false;
+}
+
 inline bool getBoolEnv(char const* name) {
   char const* env = std::getenv(name);
   return env && env[0] == '1' && env[1] == '\0';
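For reference, a standalone sketch of how these two helpers combine into the runtime gate used by the group GEMM dispatch (built with nvcc against the CUDA runtime; the lowercase helper names are local to this sketch and error handling is reduced to early returns):

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <string>

// Same semantics as getBoolEnv above: only the exact string "1" enables the flag.
static bool get_bool_env(const char* name) {
  const char* env = std::getenv(name);
  return env && env[0] == '1' && env[1] == '\0';
}

// Same idea as isDeviceType above: compare the current device's name string.
static bool is_device_type(const std::string& device_type) {
  int device_count = 0;
  if (cudaGetDeviceCount(&device_count) != cudaSuccess || device_count < 1) return false;
  int device_id = 0;
  if (cudaGetDevice(&device_id) != cudaSuccess) return false;
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) return false;
  return device_type == prop.name;
}

int main() {
  bool tuned = get_bool_env("SGL_TUNE_DEVICE_KERNEL") && is_device_type("NVIDIA H20");
  std::printf("H20-tuned group GEMM path enabled: %s\n", tuned ? "yes" : "no");
  return 0;
}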