sglang · Commit 6fc93575 (Unverified)
Authored May 16, 2025 by Elfie Guo; committed by GitHub on May 16, 2025.

[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)

Parent: 839fb31e
Changes: 12 changed files with 895 additions and 40 deletions (+895 / -40)
Changed files:
  python/sglang/srt/layers/moe/cutlass_moe.py           +207 / -0
  python/sglang/srt/layers/quantization/fp8.py            +90 / -0
  python/sglang/srt/layers/quantization/fp8_utils.py       +6 / -0
  python/sglang/test/test_cutlass_moe.py                 +278 / -0
  sgl-kernel/CMakeLists.txt                                 +1 / -0
  sgl-kernel/csrc/common_extension.cc                       +7 / -3
  sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu         +111 / -36
  sgl-kernel/csrc/moe/prepare_moe_input.cu                +128 / -0
  sgl-kernel/include/sgl_kernel_ops.h                      +18 / -1
  sgl-kernel/python/sgl_kernel/__init__.py                  +1 / -0
  sgl-kernel/python/sgl_kernel/moe.py                      +36 / -0
  sgl-kernel/tests/test_fp8_blockwise_moe.py               +12 / -0
python/sglang/srt/layers/moe/cutlass_moe.py (new file, mode 0 → 100755)
"""Cutlass MoE kernel."""
import
functools
import
json
import
logging
import
os
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
sglang.srt.utils
import
is_cuda
_is_cuda
=
is_cuda
()
if
_is_cuda
:
import
sgl_kernel
from
sgl_kernel
import
(
fp8_blockwise_scaled_grouped_mm
,
prepare_moe_input
,
silu_and_mul
,
)
def
cutlass_fused_experts
(
a
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
a1_strides
:
torch
.
Tensor
,
c1_strides
:
torch
.
Tensor
,
a2_strides
:
torch
.
Tensor
,
c2_strides
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
a_ptrs
:
torch
.
Tensor
,
b_ptrs
:
torch
.
Tensor
,
out_ptrs
:
torch
.
Tensor
,
a_scales_ptrs
:
torch
.
Tensor
,
b_scales_ptrs
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
problem_sizes1
:
torch
.
Tensor
,
problem_sizes2
:
torch
.
Tensor
,
use_fp8_blockscale
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU
activation, leveraging custom kernels likely derived from CUTLASS principles
for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and
data preparation (`prepare_moe_input`, `silu_and_mul`).
It handles per-token routing, quantizes input activations to FP8 with
per-token scales, performs the expert computations using FP8 GEMMs with
pre-quantized FP8 weights (per-block scales), applies the SiLU activation,
and combines the results weighted by the router scores.
Args:
a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total
number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
or `torch.bfloat16`.
w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
(up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
`E` is the number of experts, `k` is the hidden size, and `n*2` is the
intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
(down-projection). Expected shape: `(E, n, k)`, where `n` is half the
intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.
topk_weights (torch.Tensor): Router weights for the selected top-k experts
for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.
topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.
Shape: `(m, topk)`. Dtype: `torch.int32`.
a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
as it's passed as both a_stride and b_stride in the first call.
c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
as it's passed as both a_stride and b_stride in the second call.
c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
workspace (torch.Tensor): Reusable workspace for the underlying kernel.
a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.
b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.
out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
block scaling. Currently, only `True` is supported. Defaults to `True`.
Returns:
torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
Raises:
AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.
NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
"""
assert
use_fp8_blockscale
,
"Only support fp8 blockscale for now"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1_q
.
shape
[
2
]
==
w2_q
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
if
is_cuda
:
from
sglang.srt.layers.quantization.fp8_kernel
import
(
sglang_per_token_group_quant_fp8
,
)
out_dtype
=
a
.
dtype
num_experts
=
w1_q
.
size
(
0
)
m
=
a
.
size
(
0
)
k
=
w1_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
a_q
,
a1_scale
=
sglang_per_token_group_quant_fp8
(
a
,
128
)
device
=
a_q
.
device
a_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
prepare_moe_input
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
k
,
)
rep_a_q
=
a_q
.
view
(
dtype
=
torch
.
uint8
)[
a_map
].
view
(
dtype
=
a_q
.
dtype
)
rep_a1_scales
=
a1_scale
[
a_map
]
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
torch
.
empty
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
a_sf_layout
=
torch
.
empty
((
num_experts
,
5
),
device
=
device
,
dtype
=
torch
.
int
)
w_sf_layout
=
torch
.
empty
((
num_experts
,
5
),
device
=
device
,
dtype
=
torch
.
int
)
fp8_blockwise_scaled_grouped_mm
(
c1
,
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
a1_strides
,
a1_strides
,
c1_strides
,
a_sf_layout
,
w_sf_layout
,
problem_sizes1
,
expert_offsets
[:
-
1
],
workspace
,
)
intermediate
=
torch
.
empty
((
m
*
topk
,
n
),
device
=
device
,
dtype
=
out_dtype
)
silu_and_mul
(
c1
,
intermediate
)
intemediate_q
,
a2_scale
=
sglang_per_token_group_quant_fp8
(
intermediate
,
128
)
fp8_blockwise_scaled_grouped_mm
(
c2
,
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
intemediate_q
,
w2_q
,
a2_scale
,
w2_scale
,
a2_strides
,
a2_strides
,
c2_strides
,
a_sf_layout
,
w_sf_layout
,
problem_sizes2
,
expert_offsets
[:
-
1
],
workspace
,
)
return
(
c2
[
c_map
].
view
(
m
,
topk
,
k
)
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)
).
sum
(
dim
=
1
)
python/sglang/srt/layers/quantization/fp8.py

@@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
+    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

@@ -470,6 +471,7 @@ class Fp8MoEMethod:
     def __init__(self, quant_config):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
+        self.cutlass_fp8_supported = cutlass_fp8_supported()

     def create_weights(
         self,

@@ -568,6 +570,63 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
+            if (
+                get_bool_env_var("CUTLASS_MOE")
+                and self.cutlass_fp8_supported
+                and is_sm100_supported()
+            ):
+                self.ab_strides1 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides1 = torch.full(
+                    (num_experts,),
+                    2 * intermediate_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.ab_strides2 = torch.full(
+                    (num_experts,),
+                    intermediate_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides2 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.workspace = torch.empty(
+                    90000, device=w13_weight.device, dtype=torch.uint8
+                )
+                self.a_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.out_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.a_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.expert_offsets = torch.empty(
+                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes1 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes2 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
         else:
             # Allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.

@@ -913,6 +972,37 @@ class Fp8MoEMethod:
         if ret is not None:
             return ret
+        if (
+            get_bool_env_var("CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and is_sm100_supported()
+        ):
+            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+
+            return cutlass_fused_experts(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv.transpose(1, 2),
+                layer.w2_weight_scale_inv.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                self.workspace,
+                self.a_ptr,
+                self.b_ptr,
+                self.out_ptr,
+                self.a_scales_ptr,
+                self.b_scales_ptr,
+                self.expert_offsets,
+                self.problem_sizes1,
+                self.problem_sizes2,
+                use_fp8_blockscale=True,
+            )
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
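Taken together, the gating above means the CUTLASS path is opt-in and only selected when the CUTLASS_MOE environment variable is set, CUTLASS FP8 is supported, the weights are block-quantized, and the GPU is SM100. A quick environment sanity check along these lines, an illustrative sketch using the helpers touched by this commit, might look like:

import os
import torch
from sglang.srt.layers.quantization.fp8_utils import (
    cutlass_fp8_supported,
    is_sm100_supported,
)

# The CUTLASS MoE path is opt-in: it requires CUTLASS_MOE=1, CUTLASS FP8 support,
# block-quantized FP8 weights, and an SM100 (Blackwell) GPU with CUDA >= 12.8.
os.environ["CUTLASS_MOE"] = "1"
print("cutlass fp8 supported:", cutlass_fp8_supported())
print("sm100 supported:", is_sm100_supported())
print("cuda version:", torch.version.cuda)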
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -80,6 +80,12 @@ def cutlass_fp8_supported():
     return False


+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
python/sglang/test/test_cutlass_moe.py (new file, mode 0 → 100755)

import argparse
import time

import torch
import triton  # Added import
import triton.testing  # Added import
from transformers import AutoConfig

from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts


def get_model_config(tp_size: int):
    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-R1", trust_remote_code=True
    )
    E = config.n_routed_experts
    topk = config.num_experts_per_tok
    intermediate_size = config.moe_intermediate_size
    shard_intermediate_size = 2 * intermediate_size // tp_size

    return {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": config.quantization_config["weight_block_size"],
    }


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    """Converts tensor to FP8 E4M3, scaling values to fit the range."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate max absolute value safely
    max_val = torch.max(torch.abs(tensor))
    # Avoid division by zero if tensor is all zeros
    if max_val == 0:
        scale_factor = 1.0
    else:
        # Scale factor to bring the max value to finfo.max
        scale_factor = finfo.max / max_val
    # Apply scaling
    scaled_tensor = tensor * scale_factor
    # Clamp and convert
    fp8_tensor = scaled_tensor.clamp(min=finfo.min, max=finfo.max).to(
        dtype=torch.float8_e4m3fn
    )
    return fp8_tensor


def run_test(tp_size, batch_size, model_config, check=False):
    print(f"\n--- Batch Size: {batch_size} ---")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(42)  # For reproducible random numbers

    E = model_config["num_experts"]
    topk = model_config["topk"]
    H = model_config["hidden_size"]
    I = model_config["shard_intermediate_size"]
    block_shape = model_config["block_shape"]  # Tuple (BLOCK_N, BLOCK_K)
    dtype = model_config["dtype"]  # e.g., torch.bfloat16

    print(
        f"Config: E={E}, topk={topk}, H={H}, I_shard={I}, dtype={dtype}, block_shape={block_shape}"
    )

    # --- Input Data ---
    # Use bf16/fp16 for input activation based on model config
    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001

    # --- Weights (Generate in higher precision, then convert to FP8) ---
    # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
    w1_hp = (
        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
    )
    w2_hp = (
        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
        + 0.00001
    )
    w1 = to_fp8(w1_hp)
    w2 = to_fp8(w2_hp)

    # --- Scales for FP8 Weights ---
    block_n, block_k = block_shape
    # Calculate number of blocks needed
    w1_blocks_dim1 = (I + block_n - 1) // block_n
    w1_blocks_dim2 = (H + block_k - 1) // block_k
    w2_blocks_dim1 = (H + block_n - 1) // block_n
    w2_blocks_dim2 = (I // 2 + block_k - 1) // block_k

    # Scales are typically float32 or float16/bfloat16
    scale_dtype = torch.float32  # Or dtype if scales match model dtype
    w1_scale = torch.full(
        (E, w1_blocks_dim1, w1_blocks_dim2), 1, device="cuda", dtype=scale_dtype
    )  # Avoid zero scales
    w2_scale = torch.full(
        (E, w2_blocks_dim1, w2_blocks_dim2), 1, device="cuda", dtype=scale_dtype
    )  # Avoid zero scales

    # --- Routing Information ---
    topk_weights = torch.softmax(
        torch.rand(batch_size, topk, device="cuda", dtype=dtype), dim=-1
    )
    topk_ids = torch.randint(0, E, (batch_size, topk), dtype=torch.int32, device="cuda")

    a1_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
    c1_strides = torch.full((E,), I, dtype=torch.int64, device="cuda")
    a2_strides = torch.full((E,), I // 2, dtype=torch.int64, device="cuda")
    c2_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")

    workspace = torch.empty(
        (7182 * 1024), device="cuda", dtype=torch.uint8
    )  # Allocate sufficient workspace

    # Pointer arrays (often filled by the kernel or a prep step, but needed as args)
    a_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    b_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    out_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    a_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    b_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    expert_offsets = torch.empty((E + 1,), dtype=torch.int32, device="cuda")
    problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
    problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")

    # --- Lambdas for Benchmarking ---
    cutlass_lambda = lambda: cutlass_fused_experts(
        x,
        w1.transpose(1, 2),  # Transposed
        w2.transpose(1, 2),  # Transposed
        w1_scale.transpose(1, 2),
        w2_scale.transpose(1, 2),
        topk_weights,
        topk_ids,
        a1_strides,
        c1_strides,
        a2_strides,
        c2_strides,
        workspace,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
    )

    # Note: Triton expects non-transposed weights
    triton_lambda = lambda: fused_experts(
        x,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,  # Use False for benchmarking to avoid side effects if run multiple times
        activation="silu",  # Assuming SiLU activation common in MoEs
        use_fp8_w8a8=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        block_shape=block_shape,
    )

    # --- Warmup ---
    print("Warming up...")
    for _ in range(10):
        _ = cutlass_lambda()
        _ = triton_lambda()
    torch.cuda.synchronize()

    # --- Benchmarking ---
    quantiles = [0.5, 0.2, 0.8]
    print(f"Benchmarking Cutlass fused_experts...")
    cutlass_ms, cutlass_min, cutlass_max = triton.testing.do_bench_cudagraph(
        cutlass_lambda, rep=1000, quantiles=quantiles
    )
    print(f"Benchmarking Triton fused_experts...")
    triton_ms, triton_min, triton_max = triton.testing.do_bench_cudagraph(
        triton_lambda, rep=1000, quantiles=quantiles
    )

    print(
        f"Cutlass fused_experts time: {cutlass_ms:.3f} ms (median) [{cutlass_min:.3f} - {cutlass_max:.3f}]"
    )
    print(
        f"Triton fused_experts time: {triton_ms:.3f} ms (median) [{triton_min:.3f} - {triton_max:.3f}]"
    )

    # --- Correctness Check ---
    if check:
        print("Running correctness check...")
        with torch.no_grad():
            # Run CUTLASS version (requires transposed weights)
            y_cutlass = cutlass_fused_experts(
                x,
                w1.transpose(1, 2),  # Transposed
                w2.transpose(1, 2),  # Transposed
                w1_scale.transpose(1, 2),
                w2_scale.transpose(1, 2),
                topk_weights,
                topk_ids,
                a1_strides,
                c1_strides,
                a2_strides,
                c2_strides,
                workspace,
                a_ptrs,
                b_ptrs,
                out_ptrs,
                a_scales_ptrs,
                b_scales_ptrs,
                expert_offsets,
                problem_sizes1,
                problem_sizes2,
            )

            # Run Triton version (requires original shape weights, use inplace=False)
            y_triton = fused_experts(
                x,
                w1,  # Original shape
                w2,  # Original shape
                topk_weights,
                topk_ids,
                inplace=False,  # Important: Use False to get output tensor
                activation="silu",
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                block_shape=block_shape,
            )

        # Ensure outputs are same dtype for comparison
        y_cutlass = y_cutlass.to(dtype)
        y_triton = y_triton.to(dtype)

        abs_error = torch.abs(y_cutlass - y_triton)
        rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)

        max_abs_err = abs_error.max().item()
        max_rel_err = rel_error.max().item()

        print("y_cutlass:", y_cutlass[:, :10])
        print("y_triton:", y_triton[:, :10])
        print(f"Max absolute error: {max_abs_err:.6f}")
        print(f"Max relative error: {max_rel_err:.6f}")

        # Tolerance might need adjustment based on FP8 specifics and kernel differences
        # FP8 comparisons often require higher tolerance than FP16/BF16
        assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
        print("Correctness check passed.")


def main(
    tp_size=8,
    batch_sizes=[1, 4, 8, 16, 32, 64, 128, 256, 512],
    check=False,
):
    model_config = get_model_config(tp_size)
    print("Model Config:", model_config)
    for batch_size in batch_sizes:
        run_test(tp_size, batch_size, model_config, check)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp-size", type=int, default=8, help="Tensor Parallel size")
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 4, 8, 16, 32, 64, 128, 256, 512],  # Adjusted default
        help="List of batch sizes to test",
    )
    parser.add_argument("--check", action="store_true", help="Enable check mode")
    args = parser.parse_args()

    print(f"Running benchmarks with TP size: {args.tp_size}")
    print(f"Testing batch sizes: {args.batch_sizes}")
    main(tp_size=args.tp_size, batch_sizes=args.batch_sizes, check=args.check)
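Given the argparse options above, the benchmark can be invoked as, for example, python python/sglang/test/test_cutlass_moe.py --tp-size 8 --batch-sizes 1 64 512 --check (an SM100 GPU and access to the deepseek-ai/deepseek-R1 config are required; the exact flags are simply those defined by the parser above).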
sgl-kernel/CMakeLists.txt (mode 100755 → 100644)

@@ -207,6 +207,7 @@ set(SOURCES
   "csrc/moe/moe_fused_gate.cu"
   "csrc/moe/moe_topk_softmax_kernels.cu"
   "csrc/moe/fp8_blockwise_moe_kernel.cu"
+  "csrc/moe/prepare_moe_input.cu"
   "csrc/speculative/eagle_utils.cu"
   "csrc/speculative/speculative_sampling.cu"
   "csrc/speculative/packbit.cu"
sgl-kernel/csrc/common_extension.cc

@@ -151,11 +151,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);

   m.def(
-      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
-      "stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
-      "expert_offsets) -> ()");
+      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
+      "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
+      "stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
+      "expert_offsets, Tensor workspace) -> ()");
   m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);

+  m.def(
+      "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
+      "input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
+  m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
+
   /*
    * From csrc/speculative
    */
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu

+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/arch/arch.h>
 #include <torch/all.h>
...
@@ -49,23 +51,16 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
   using ElementC = OutType;
   using ElementD = ElementC;
   using ElementAccumulator = float;
-  // Layout definitions
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = LayoutD;
-  // Alignment constraints
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
-  // Architecture definitions
   using ArchTag = cutlass::arch::Sm100;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
-  // For fp8 block scale.
-  // using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
-  //     ScaleGranularityK, cute::UMMA::Major::K, cute::UMMA::Major::K>; using LayoutSFA =
-  //     decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
       OperatorClass,
...
@@ -124,9 +119,8 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = 0;
-  hw_info.sm_count = 1;
-  // Currently, we are only able to do broadcast on either all or none a_scales
-  // and on either all or none b_scales
+  // sm_count is the number of SMs on the current device, since we only support SM100 blackwell, so we set it to 148
+  hw_info.sm_count = 148;
   typename GemmKernel::EpilogueArguments epilogue_args{
       {},
       nullptr,
...
@@ -134,9 +128,7 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<StrideC*>(stride_c.data_ptr())};
-  // Initialize problem_sizes_as_shapes correctly
   UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
-  // Use prob_shape in the GEMM arguments
   typename GemmKernel::Arguments args{
       cutlass::gemm::GemmUniversalMode::kGrouped,
       {num_experts, problem_sizes_as_shapes, nullptr},
...
@@ -144,21 +136,27 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
       epilogue_args,
       hw_info};

+  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
+
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

-  // Run the GEMM
-  auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
+  auto status = gemm_op.initialize(args, workspace.data_ptr());
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

-  status = gemm_op.run();
+  status = gemm_op.run(stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }
...
 template <typename OutType>
 void sm100_fp8_blockwise_group_mm_dispatch_shape(
     torch::Tensor& output,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& scales_a,
...
@@ -169,11 +167,23 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
   // Check the first matrix size to decide on the configuration
   // Assuming all matrices in the group have similar size characteristics
   // bool use_small_config = a[0].size(0) <= 128;
-  struct MMALargeConfig {
+  struct MmaConfig1 {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _32, _128>;
+    using ClusterShape = Shape<_1, _1, _1>;  // Layout type for SFB matrix operand
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+    using ScaleConfig =
+        cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
+
+  struct MmaConfig2 {
     using ElementA = cutlass::float_e4m3_t;
     using MmaTileShape = Shape<_128, _128, _128>;
     using ClusterShape = Shape<_1, _1, _1>;  // Layout type for SFB matrix operand
...
@@ -184,35 +194,28 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };

-  struct MMASmallConfig {
+  struct MmaConfig3 {
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_128, _16, _128>;
+    using MmaTileShape = Shape<_64, _128, _128>;
     using ClusterShape = Shape<_1, _1, _1>;  // Layout type for SFB matrix operand
     using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
     using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
     using ScaleConfig =
-        cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
+        cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };

   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
-  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
-  torch::Tensor workspace = torch::empty(100, options_int);

   torch::Tensor output_t = output.t();
   torch::Tensor a_t = a.t();
   torch::Tensor b_t = b.transpose(1, 2);
   torch::Tensor scales_a_t = scales_a.t();
   torch::Tensor scales_b_t = scales_b.transpose(1, 2);

-  if (a.size(0) <= 512) {
-    run_get_group_gemm_starts<MMASmallConfig::LayoutSFA, MMASmallConfig::LayoutSFB, MMASmallConfig::ScaleConfig>(
+  if (a.size(0) <= 512 && a.size(1) >= 2048) {
+    run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
         expert_offsets,
         a_ptrs,
         b_ptrs,
...
@@ -229,7 +232,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
         problem_sizes,
         problem_sizes_transpose,
         true);
-    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMASmallConfig, cutlass::layout::ColumnMajor>(
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
         out_ptrs,
         a_ptrs,
         b_ptrs,
...
@@ -244,8 +247,39 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
         expert_offsets,
         workspace);
     output = output_t.t();
+  } else if (a.size(0) > 512 && a.size(1) >= 2048) {
+    run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig2, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
   } else {
-    run_get_group_gemm_starts<MMALargeConfig::LayoutSFA, MMALargeConfig::LayoutSFB, MMALargeConfig::ScaleConfig>(
+    run_get_group_gemm_starts<MmaConfig3::LayoutSFA, MmaConfig3::LayoutSFB, MmaConfig3::ScaleConfig>(
         expert_offsets,
         a_ptrs,
         b_ptrs,
...
@@ -261,7 +295,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
         layout_sfb,
         problem_sizes,
         problem_sizes_transpose);
-    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMALargeConfig, cutlass::layout::RowMajor>(
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig3, cutlass::layout::RowMajor>(
         out_ptrs,
         a_ptrs,
         b_ptrs,
...
@@ -312,6 +346,11 @@
 */
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& scales_a,
...
@@ -322,7 +361,8 @@ void fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
   TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
   TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
   TORCH_CHECK(
...
@@ -342,6 +382,29 @@ void fp8_blockwise_scaled_grouped_mm(
   TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
   TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");

+  TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
+  TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
+  TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
+  TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
+  TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
+  TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor");
+  TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor");
+  TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor");
+  TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 1D tensor");
+  TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 1D tensor");
+  TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor");
+  TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor");
+  TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
+  TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
+  TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
+  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
+  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
+  TORCH_CHECK(
+      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
+  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
+  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
+  TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
+
   bool can_implement = false;
   auto sm_version = getSMVersion();
...
@@ -351,6 +414,11 @@ void fp8_blockwise_scaled_grouped_mm(
   if (output.scalar_type() == torch::kBFloat16) {
     sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
         output,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
         a,
         b,
         scales_a,
...
@@ -361,10 +429,16 @@ void fp8_blockwise_scaled_grouped_mm(
         layout_sfa,
         layout_sfb,
         problem_sizes,
-        expert_offsets);
+        expert_offsets,
+        workspace);
   } else {
     sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
         output,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
         a,
         b,
         scales_a,
...
@@ -375,7 +449,8 @@ void fp8_blockwise_scaled_grouped_mm(
         layout_sfa,
         layout_sfb,
         problem_sizes,
-        expert_offsets);
+        expert_offsets,
+        workspace);
   }
   can_implement = true;
 }
...
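As a reading aid, the tile-configuration dispatch introduced above can be summarized as follows. This is an illustrative sketch of the branch conditions visible in the diff, where m = a.size(0) and k = a.size(1); the config names mirror the structs above.

def pick_mma_config(m: int, k: int) -> str:
    """Mirror of the branches in sm100_fp8_blockwise_group_mm_dispatch_shape (illustrative only)."""
    if m <= 512 and k >= 2048:
        return "MmaConfig1"  # 128x32x128 tile, column-major output path
    elif m > 512 and k >= 2048:
        return "MmaConfig2"  # 128x128x128 tile, row-major output path
    else:
        return "MmaConfig3"  # 64x128x128 tile, row-major output path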
sgl-kernel/csrc/moe/prepare_moe_input.cu (new file, mode 0 → 100644)

#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int topk_length,
    const int n,
    const int k) {
  int expert_id = blockIdx.x;

  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = 2 * n;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = k;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}

__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}

__global__ void compute_arg_sorts(
    const int* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int topk_length,
    const int topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    if (topk_ids[i] == expert_id) {
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}

void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      n,
      k);
  compute_expert_offsets<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      num_experts);
  compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      topk_ids.size(1));
}

void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32);
  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
  return;
}
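For readers following the kernel logic, here is an equivalent pure-Python sketch (illustrative only, not part of the commit) of what the three kernels above compute from topk_ids: per-expert problem sizes, exclusive expert offsets, and the input/output permutations used to gather tokens into expert-contiguous order. Note the GPU version assigns slots within an expert via atomics, so the intra-expert ordering may differ from this sequential reference.

import torch

def prepare_moe_input_reference(topk_ids, num_experts, n, k):
    """CPU reference for compute_problem_sizes / compute_expert_offsets / compute_arg_sorts."""
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)  # tokens routed to each expert
    problem_sizes1 = torch.stack(
        [counts, torch.full_like(counts, 2 * n), torch.full_like(counts, k)], dim=1
    )
    problem_sizes2 = torch.stack(
        [counts, torch.full_like(counts, k), torch.full_like(counts, n)], dim=1
    )
    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64)
    expert_offsets[1:] = torch.cumsum(counts, dim=0)  # exclusive prefix sum of counts
    topk = topk_ids.shape[1]
    input_permutation = torch.empty_like(flat)
    output_permutation = torch.empty_like(flat)
    cursor = expert_offsets[:-1].clone()
    for i, e in enumerate(flat.tolist()):
        slot = int(cursor[e])
        cursor[e] += 1
        input_permutation[slot] = i // topk  # which token row to gather for this packed slot
        output_permutation[i] = slot         # where this (token, expert) pair lands in the packed output
    return problem_sizes1, problem_sizes2, expert_offsets, input_permutation, output_permutation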
sgl-kernel/include/sgl_kernel_ops.h

@@ -211,6 +211,11 @@ std::vector<at::Tensor> moe_fused_gate(
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& scales_a,
...
@@ -221,7 +226,19 @@ void fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets);
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace);
+
+void prepare_moe_input(
+    const torch::Tensor& topk_ids,
+    torch::Tensor& expert_offsets,
+    torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2,
+    torch::Tensor& input_permutation,
+    torch::Tensor& output_permutation,
+    const int64_t num_experts,
+    const int64_t n,
+    const int64_t k);

 /*
  * From csrc/speculative
sgl-kernel/python/sgl_kernel/__init__.py

@@ -47,6 +47,7 @@ from sgl_kernel.moe import (
     fp8_blockwise_scaled_grouped_mm,
     moe_align_block_size,
     moe_fused_gate,
+    prepare_moe_input,
     topk_softmax,
 )
 from sgl_kernel.sampling import (
sgl-kernel/python/sgl_kernel/moe.py

@@ -64,6 +64,11 @@ def moe_fused_gate(
 def fp8_blockwise_scaled_grouped_mm(
     output,
+    a_ptrs,
+    b_ptrs,
+    out_ptrs,
+    a_scales_ptrs,
+    b_scales_ptrs,
     a,
     b,
     scales_a,
...
@@ -75,9 +80,15 @@ def fp8_blockwise_scaled_grouped_mm(
     layout_sfb,
     problem_sizes,
     expert_offsets,
+    workspace,
 ):
     torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
         output,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
         a,
         b,
         scales_a,
...
@@ -89,4 +100,29 @@ def fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         problem_sizes,
         expert_offsets,
+        workspace,
     )
+
+
+def prepare_moe_input(
+    topk_ids,
+    expert_offsets,
+    problem_sizes1,
+    problem_sizes2,
+    input_permutation,
+    output_permutation,
+    num_experts,
+    n,
+    k,
+):
+    torch.ops.sgl_kernel.prepare_moe_input.default(
+        topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        input_permutation,
+        output_permutation,
+        num_experts,
+        n,
+        k,
+    )
sgl-kernel/tests/test_fp8_blockwise_moe.py

@@ -131,9 +131,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     c_strides = torch.full(
         (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
     )
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
+    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)

     fp8_blockwise_scaled_grouped_mm(
         c_out,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
         a_stack,
         b_stack,
         a_scale_stack,
...
@@ -145,6 +156,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         layout_sfb,
         problem_sizes,
         expert_offsets[:-1],
+        workspace,
     )

     for g in range(num_experts):
...