sglang · commit dc48c4c0 (Unverified)
Authored Oct 14, 2025 by Qi Yuhang; committed by GitHub on Oct 13, 2025
Parent: 6dc9ca8c

[sgl-kernel][2/N] Support Expert Specialization Grouped GEMM (#11534)

Showing 4 changed files with 112 additions and 106 deletions (+112 -106)
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu             +58 -20
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh   +53 -85
sgl-kernel/python/sgl_kernel/__init__.py                              +1 -1
sgl-kernel/python/sgl_kernel/expert_specialization.py (renamed)       +0 -0
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu

@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
   torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
   torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
   torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
-  expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
-      out_ptrs,
-      a_ptrs,
-      b_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      layout_sfa,
-      layout_sfb,
-      lm_problem_sizes,
-      mm_problem_sizes,
-      hm_problem_sizes,
-      output,
-      a,
-      b,
-      scales_a,
-      scales_b,
-      problem_sizes,
-      expert_offsets);
+  const std::string H20_device_type_str("NVIDIA H20");
+  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
+  at::cuda::CUDAGuard device_guard{(char)a.get_device()};
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+  if (output.dtype() == torch::kBFloat16) {
+    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        layout_sfa,
+        layout_sfb,
+        lm_problem_sizes,
+        mm_problem_sizes,
+        hm_problem_sizes,
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        problem_sizes,
+        expert_offsets,
+        is_h20_device,
+        stream);
+  } else if (output.dtype() == torch::kFloat16) {
+    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        layout_sfa,
+        layout_sfb,
+        lm_problem_sizes,
+        mm_problem_sizes,
+        hm_problem_sizes,
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        problem_sizes,
+        expert_offsets,
+        is_h20_device,
+        stream);
+  } else {
+    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+  }
   if (output.dtype() == torch::kBFloat16) {
     expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
         out_ptrs,
         ...
@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         lm_problem_sizes,
         mm_problem_sizes,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        is_h20_device,
+        stream);
   } else if (output.dtype() == torch::kFloat16) {
     expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
         out_ptrs,
         ...
@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         lm_problem_sizes,
         mm_problem_sizes,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        is_h20_device,
+        stream);
   } else {
     TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
   }
 }
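The wrapper above now performs the device-name check, the CUDA device guard, and the stream lookup once, then dispatches on the output dtype to a dtype-templated precompute helper. The following is a minimal standalone sketch of that caller-side pattern, assuming the usual PyTorch/ATen headers; precompute_for and dispatch_on_out_dtype are hypothetical names used only for illustration, not symbols from this commit.

#include <string>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// Hypothetical stand-in for a dtype-templated precompute helper such as
// es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<T>.
template <typename T>
void precompute_for(const torch::Tensor& output, bool is_h20_device, cudaStream_t stream) {
  // ... enqueue the per-expert precompute for element type T on `stream` ...
  (void)output;
  (void)is_h20_device;
  (void)stream;
}

void dispatch_on_out_dtype(const torch::Tensor& output) {
  // Device-name check, device guard, and stream lookup done once by the caller,
  // as in the hunk above.
  bool is_h20_device =
      std::string(at::cuda::getCurrentDeviceProperties()->name) == "NVIDIA H20";
  at::cuda::CUDAGuard device_guard{(char)output.get_device()};
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(output.get_device());

  if (output.dtype() == torch::kBFloat16) {
    precompute_for<at::BFloat16>(output, is_h20_device, stream);
  } else if (output.dtype() == torch::kFloat16) {
    precompute_for<at::Half>(output, is_h20_device, stream);
  } else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

The commit's actual helper takes the full set of pointer, scale, and problem-size tensors; the sketch only shows the hoisted guard/stream setup and the dtype dispatch, which together let every launch triggered by one call share the same stream instead of re-querying it inside each helper.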
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh

@@ -3,6 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
+#include <cassert>
 #include <iostream>
 #include <string>
 ...
@@ -14,6 +15,7 @@ namespace expert_specialization {
 using namespace cute;
+template <typename T>
 void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
     // Output
     torch::Tensor& out_ptrs,
     ...
@@ -33,15 +35,14 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& problem_sizes,
-    torch::Tensor const& expert_offsets) {
+    torch::Tensor const& expert_offsets,
+    bool is_h20_device,
+    cudaStream_t stream) {
   TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-  const std::string H20_device_type_str("NVIDIA H20");
-  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
   // Creat Scale Factor Layout Functor
   using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
   using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
   ...
@@ -49,74 +50,38 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
       reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
       reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
   int num_experts = (int)expert_offsets.size(0);
-  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
-  // Dispatch
-  if (out_tensors.dtype() == torch::kBFloat16) {
-    struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::bfloat16_t> of(
-        static_cast<int*>(expert_offsets.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
-        static_cast<cutlass::bfloat16_t*>(out_tensors.data_ptr()),
-        static_cast<float*>(a_scales.data_ptr()),
-        static_cast<float*>(b_scales.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
-        static_cast<float**>(a_scales_ptrs.data_ptr()),
-        static_cast<float**>(b_scales_ptrs.data_ptr()),
-        static_cast<cutlass::bfloat16_t**>(out_ptrs.data_ptr()));
-    if (!is_h20_device) {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    } else {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    }
-  } else if (out_tensors.dtype() == torch::kFloat16) {
-    struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::half_t> of(
-        static_cast<int*>(expert_offsets.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
-        static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
-        static_cast<cutlass::half_t*>(out_tensors.data_ptr()),
-        static_cast<float*>(a_scales.data_ptr()),
-        static_cast<float*>(b_scales.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
-        static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
-        static_cast<float**>(a_scales_ptrs.data_ptr()),
-        static_cast<float**>(b_scales_ptrs.data_ptr()),
-        static_cast<cutlass::half_t**>(out_ptrs.data_ptr()));
-    if (!is_h20_device) {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    } else {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
-          static_cast<int*>(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
-          static_cast<int*>(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
-          static_cast<int*>(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    }
-  } else {
-    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
-  }
+  TORCH_CHECK(num_experts <= 1024, "Expert more than 1024");  // Max threads per block is 1024
+  struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, T> of(
+      static_cast<int*>(expert_offsets.data_ptr()),
+      static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
+      static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
+      static_cast<T*>(out_tensors.data_ptr()),
+      static_cast<float*>(a_scales.data_ptr()),
+      static_cast<float*>(b_scales.data_ptr()),
+      static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
+      static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
+      static_cast<float**>(a_scales_ptrs.data_ptr()),
+      static_cast<float**>(b_scales_ptrs.data_ptr()),
+      static_cast<T**>(out_ptrs.data_ptr()));
+  if (!is_h20_device) {
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
+        static_cast<int*>(lm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
+        static_cast<int*>(mm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
+        static_cast<int*>(hm_problem_sizes.data_ptr()));
+    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
+        static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
+  } else {
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
+        static_cast<int*>(lm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
+        static_cast<int*>(mm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
+        static_cast<int*>(hm_problem_sizes.data_ptr()));
+    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
+        static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
+  }
 }
 ...
@@ -132,7 +97,8 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
     const torch::Tensor& stride_d,
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
-    const torch::Tensor& problem_sizes) {
+    const torch::Tensor& problem_sizes,
+    cudaStream_t stream) {
   using ElementA = typename GemmTraits::ElementA;
   using StrideA = typename GemmTraits::StrideA;
   using ElementB = typename GemmTraits::ElementB;
   ...
@@ -174,9 +140,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
       epilogue_args,
       hw_info};
-  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
   ...
@@ -205,7 +168,9 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
     const torch::Tensor& layout_sfb,
     const torch::Tensor& lm_problem_sizes,
     const torch::Tensor& mm_problem_sizes,
-    const torch::Tensor& hm_problem_sizes) {
+    const torch::Tensor& hm_problem_sizes,
+    bool is_h20_device,
+    cudaStream_t stream) {
   using LowMGemmH20Traits =
       ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMH20>;
   using LowMGemmHx00Traits =
   ...
@@ -221,9 +186,6 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
   using HighMGemmHx00Traits =
       ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;
-  const std::string H20_device_type_str("NVIDIA H20");
-  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
   if (!is_h20_device) {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmHx00Traits>(
         out_ptrs,
         ...
@@ -236,7 +198,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        lm_problem_sizes);
+        lm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
         out_ptrs,
         ...
@@ -249,7 +212,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        lm_problem_sizes);
+        lm_problem_sizes,
+        stream);
   }
   if (!is_h20_device) {
   ...
@@ -264,7 +228,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        mm_problem_sizes);
+        mm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
         out_ptrs,
         ...
@@ -277,7 +242,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        mm_problem_sizes);
+        mm_problem_sizes,
+        stream);
   }
   if (!is_h20_device) {
   ...
@@ -292,7 +258,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
         out_ptrs,
         ...
@@ -305,7 +272,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        stream);
   }
 }
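The precompute path above launches groupedGemmPreComputeKernel as a single block with one thread per expert (<<<1, num_experts, 0, stream>>>), which is why the new TORCH_CHECK caps num_experts at 1024, the CUDA per-block thread limit. Below is a minimal, self-contained CUDA sketch of that launch shape; the kernel and variable names are illustrative stand-ins rather than the commit's actual kernel, and it assumes each expert's problem size is stored as an (M, N, K) triple, matching the {num_experts, 3} int32 tensors above.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in for groupedGemmPreComputeKernel: one thread per expert,
// all experts handled by a single block, so num_experts must stay <= 1024.
__global__ void per_expert_precompute(const int* problem_sizes, int* m_per_expert) {
  int expert = threadIdx.x;                          // one expert per thread
  m_per_expert[expert] = problem_sizes[3 * expert];  // pick out M from the (M, N, K) triple
}

int main() {
  const int num_experts = 8;  // must be <= 1024 for a <<<1, num_experts>>> launch
  int h_sizes[3 * num_experts];
  int h_m[num_experts];
  for (int e = 0; e < num_experts; ++e) {
    h_sizes[3 * e + 0] = 16 * (e + 1);  // M varies per expert
    h_sizes[3 * e + 1] = 256;           // N
    h_sizes[3 * e + 2] = 512;           // K
  }

  int *d_sizes = nullptr, *d_m = nullptr;
  cudaMalloc(&d_sizes, sizeof(h_sizes));
  cudaMalloc(&d_m, sizeof(h_m));
  cudaMemcpy(d_sizes, h_sizes, sizeof(h_sizes), cudaMemcpyHostToDevice);

  // Same launch shape as the diff: one block, num_experts threads (default stream here).
  per_expert_precompute<<<1, num_experts>>>(d_sizes, d_m);
  cudaMemcpy(h_m, d_m, sizeof(h_m), cudaMemcpyDeviceToHost);

  for (int e = 0; e < num_experts; ++e) {
    printf("expert %d: M = %d\n", e, h_m[e]);
  }
  cudaFree(d_sizes);
  cudaFree(d_m);
  return 0;
}

The lm/mm/hm filter functors in the diff appear to route each expert's problem into a low-, middle-, or high-M tile configuration (H20 or non-H20 variants); that selection logic is outside the scope of this sketch.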
sgl-kernel/python/sgl_kernel/__init__.py

@@ -244,7 +244,7 @@ from sgl_kernel.elementwise import (
     rmsnorm,
     silu_and_mul,
 )
-from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
+from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
 from sgl_kernel.fused_moe import fused_marlin_moe
 from sgl_kernel.gemm import (
     awq_dequantize,
     ...
sgl-kernel/python/sgl_kernel/expert_specilization.py → sgl-kernel/python/sgl_kernel/expert_specialization.py
File moved (the module is renamed to fix the "specilization" misspelling; the import in __init__.py above is updated to match).