Unverified commit ce112c07, authored by Qi Yuhang and committed by GitHub
Browse files

[sgl-kernel][4/N]Support Expert Specialization Grouped GEMM (#12080)

parent f7dc2f33
......@@ -126,7 +126,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> {
Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
if (m <= 48) {
if (m < 64) {
// Swap A/B
problem_sizes[expert_id * 3 + 0] = n;
problem_sizes[expert_id * 3 + 1] = m;
......@@ -168,7 +168,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> {
Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
if (m > 48 && m <= 96) {
if (m >= 64 && m < 128) {
problem_sizes[expert_id * 3 + 0] = m;
problem_sizes[expert_id * 3 + 1] = n;
problem_sizes[expert_id * 3 + 2] = k;
......@@ -208,7 +208,7 @@ struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> {
Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
if (m > 96) {
if (m >= 128) {
problem_sizes[expert_id * 3 + 0] = m;
problem_sizes[expert_id * 3 + 1] = n;
problem_sizes[expert_id * 3 + 2] = k;
......
......@@ -232,7 +232,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
workspace,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmH20Traits>(
out_ptrs,
a_ptrs,
b_ptrs,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment