Commit ce112c07 (unverified), authored by Qi Yuhang, committed by GitHub
Browse files

[sgl-kernel][4/N]Support Expert Specialization Grouped GEMM (#12080)

parent f7dc2f33
...@@ -126,7 +126,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> { ...@@ -126,7 +126,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> {
Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
if (m <= 48) { if (m < 64) {
// Swap A/B // Swap A/B
problem_sizes[expert_id * 3 + 0] = n; problem_sizes[expert_id * 3 + 0] = n;
problem_sizes[expert_id * 3 + 1] = m; problem_sizes[expert_id * 3 + 1] = m;
...@@ -168,7 +168,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> { ...@@ -168,7 +168,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> {
Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
if (m > 48 && m <= 96) { if (m >= 64 && m < 128) {
problem_sizes[expert_id * 3 + 0] = m; problem_sizes[expert_id * 3 + 0] = m;
problem_sizes[expert_id * 3 + 1] = n; problem_sizes[expert_id * 3 + 1] = n;
problem_sizes[expert_id * 3 + 2] = k; problem_sizes[expert_id * 3 + 2] = k;
...@@ -208,7 +208,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> { ...@@ -208,7 +208,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> {
Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
if (m > 96) { if (m >= 128) {
problem_sizes[expert_id * 3 + 0] = m; problem_sizes[expert_id * 3 + 0] = m;
problem_sizes[expert_id * 3 + 1] = n; problem_sizes[expert_id * 3 + 1] = n;
problem_sizes[expert_id * 3 + 2] = k; problem_sizes[expert_id * 3 + 2] = k;
......
...@@ -232,7 +232,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( ...@@ -232,7 +232,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
workspace, workspace,
stream); stream);
} else { } else {
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>( launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmH20Traits>(
out_ptrs, out_ptrs,
a_ptrs, a_ptrs,
b_ptrs, b_ptrs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment