change / sglang / Commits / eb38c7d1

Commit eb38c7d1 (Unverified) — authored Jun 02, 2025 by Pavani Majety, committed by GitHub Jun 02, 2025

[1/2] Add Kernel support for Cutlass based Fused FP4 MoE (#6093)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
parent df7f61ee

Showing 12 changed files with 1677 additions and 22 deletions (+1677 / -22)
python/sglang/srt/layers/moe/cutlass_moe.py    +178  -1
python/sglang/test/test_fp4_moe.py             +247  -0
sgl-kernel/CMakeLists.txt                      +2    -0
sgl-kernel/csrc/common_extension.cc            +21   -2
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu     +431  -0
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu      +23   -0
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu     +471  -0
sgl-kernel/csrc/moe/prepare_moe_input.cu       +145  -19
sgl-kernel/include/sgl_kernel_ops.h            +24   -0
sgl-kernel/python/sgl_kernel/__init__.py       +3    -0
sgl-kernel/python/sgl_kernel/gemm.py           +77   -0
sgl-kernel/python/sgl_kernel/moe.py            +55   -0
python/sglang/srt/layers/moe/cutlass_moe.py

-"""Cutlass MoE kernel."""
+"""CUTLASS based Fused MoE kernels."""

import functools
import json
...
@@ -14,8 +14,10 @@ _is_cuda = is_cuda()
if _is_cuda:
    import sgl_kernel
    from sgl_kernel import (
        cutlass_fp4_group_mm,
        fp8_blockwise_scaled_grouped_mm,
        prepare_moe_input,
        scaled_fp4_experts_quant,
        silu_and_mul,
    )
...
@@ -205,3 +207,178 @@ def cutlass_fused_experts(
    return (
        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
    ).sum(dim=1)


FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0
def cutlass_moe_fp4(
    a: torch.Tensor,
    a1_gscale: torch.Tensor,
    w1_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alphas: torch.Tensor,
    a2_gscale: torch.Tensor,
    w2_fp4: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alphas: torch.Tensor,
    ab_strides_13: torch.Tensor,
    ab_strides_2: torch.Tensor,
    c_strides_13: torch.Tensor,
    c_strides_2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
):
    """
    MoE implementation for FP4 inputs.

    # Gemm 1
    a: Input tensor: [m, k] (half/bfloat16)
    a1_gscale: Activation scale per expert: [e] (float32)
    w1 (gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
        (Note: `n` is the up projection output dim, `k` is the input dim in
        full precision)
    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
        (Block size = 16 for NVFP4)

    # Gemm 2
    a2_gscale: Activation scale per expert: [e]
    w2 (down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3

    Strides for activations, weights and output are given in logical number
    of elements. The activation and output stride is the number of elements
    to the next row; the weight stride is the number of elements to the next
    row per expert. For example, if the weight is [e, n, k], then the
    b_stride is a tensor of shape [e] with each element being k. Similarly
    for activations, if the shape is [m, k], then the a_stride has shape [e]
    with each value k, and if the output is [m, n], then the c_stride is a
    tensor of shape [e] with each element being n.

    Note: cutlass_fp4_group_mm is designed to accept the strides of
    activations and weights to be the same, so they are passed in as a
    single tensor.
    ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
    ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]
    c_strides_13: [e] dtype: int64 [Gemm 1: Output strides]
    c_strides_2: [e] dtype: int64 [Gemm 2: Output strides]
    topk_weights: [m, topk] dtype: float32
    topk_ids: [m, topk] dtype: int32
    m, n, k: Unquantized weight shapes, dtype: int
    e: number of experts for the current rank, dtype: int

    Assumes that topk < k < n, matching the up/down projection expectations.
    """
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
    assert (
        w1_fp4.ndim == 3
        and w2_fp4.ndim == 3
        and w1_blockscale.ndim == 3
        and w2_blockscale.ndim == 3
    ), "All Weights must be of rank 3 for cutlass_moe_fp4"
    m_a, k_a = a.shape
    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
    e_w2, k_w2, half_n_w2 = w2_fp4.shape

    assert e_w1 == e_w2 and e_w1 == e, "Number of experts must match between weights."
    assert (
        k_a // 2 == half_k_w1 and k == k_w2
    ), "Hidden size mismatch between a, w1 and w2"
    assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in expected `n`"
    assert m == m_a, "input shape mismatch"
    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert (
        topk_weights.shape[0] == m and topk_ids.shape[0] == m
    ), "topk must be provided for each row of a"

    out_dtype = a.dtype
    num_topk = topk_ids.shape[1]
    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    # Problem size: (num_experts, (m,2n,k))
    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
    # Problem size: (num_experts, (m,n,k))
    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    # problem shapes should have [m, n, k]
    # Note that problem sizes are based on logical number of elements.
    blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
    prepare_moe_input(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        e,
        n,
        k,
        blockscale_offsets,
    )

    rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
        a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map
    )
    c1 = cutlass_fp4_group_mm(
        rep_a_fp4,
        w1_fp4,
        rep_a_blockscale,
        w1_blockscale,
        w1_alphas,
        ab_strides_13,
        c_strides_13,
        problem_sizes1,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
        out_dtype,
        device,
    )
    del rep_a_fp4, rep_a_blockscale

    # The hidden size dimension is split into one half-sized tensor.
    intermediate = torch.empty(
        (m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
    )
    silu_and_mul(c1, intermediate)

    int_fp4, int_blockscale = scaled_fp4_experts_quant(
        intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk
    )
    c2 = cutlass_fp4_group_mm(
        int_fp4,
        w2_fp4,
        int_blockscale,
        w2_blockscale,
        w2_alphas,
        ab_strides_2,
        c_strides_2,
        problem_sizes2,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
        out_dtype,
        device,
    )
    del int_fp4, int_blockscale
    out = (
        c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()
    ).sum(dim=1)
    return out.to(dtype=out_dtype)
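The stride convention described in the docstring above is easy to get wrong, so here is a minimal sketch of how the four stride tensors are typically built for this entry point (it mirrors the new test further below; the shapes and the `w1_q`/`w2_q` names are illustrative, and the device argument is omitted):

import torch

# Hypothetical shapes, chosen only to make the arithmetic concrete.
e, n, k = 8, 1024, 1024
w1_q = torch.empty((e, 2 * n, k // 2), dtype=torch.uint8)  # packed E2M1, two values per byte
w2_q = torch.empty((e, k, n // 2), dtype=torch.uint8)

# GEMM 1 reads rows of k logical elements and writes rows of 2 * n elements.
ab_strides_13 = torch.full((e,), w1_q.shape[2] * 2, dtype=torch.int64)  # == k
c_strides_13 = torch.full((e,), w1_q.shape[1], dtype=torch.int64)       # == 2 * n
# GEMM 2 reads rows of n logical elements and writes rows of k elements.
ab_strides_2 = torch.full((e,), w2_q.shape[2] * 2, dtype=torch.int64)   # == n
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64)        # == k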
python/sglang/test/test_fp4_moe.py (new file, 0 → 100644)

# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from sgl_kernel import scaled_fp4_quant

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.topk import select_experts

if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]


def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8

    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)
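As a quick sanity check of the nibble decoding above, consider a single packed byte (the value 0x2E is arbitrary, chosen only to exercise the sign bit):

# 0x2E packs two E2M1 values, low nibble first, then high nibble.
# low  = 0x2E & 0x0F = 0b1110 -> sign bit set, magnitude index 6 -> kE2M1ToFloat[6] = 4.0 -> -4.0
# high = (0x2E & 0xF0) >> 4 = 0b0010 -> sign clear, magnitude index 2 -> kE2M1ToFloat[2] = 1.0 -> +1.0
decoded = break_fp4_bytes(torch.tensor([[0x2E]], dtype=torch.uint8), torch.float32)
assert decoded.tolist() == [[-4.0, 1.0]]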
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


# Reference implementation of torch_moe
def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype):
    torch.manual_seed(7)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    quant_blocksize = 16
    round_up = lambda x, y: (x + y - 1) // y * y
    sf_w1_2n = round_up(2 * n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )

    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )

    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)

    for expert in range(e):
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant(w1[expert], w1_gs[expert])
        w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant(w2[expert], w2_gs[expert])

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids = select_experts(
        hidden_states=a,
        router_logits=score,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=False,
    )
    a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
    a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

    # strides for the cutlass moe_fp4 kernel
    ab_strides_13 = torch.full((e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device)
    c_strides_13 = torch.full((e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device)
    ab_strides_2 = torch.full((e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device)
    c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)

    cutlass_output = cutlass_moe_fp4(
        a=a,
        a1_gscale=a1_gs,
        w1_fp4=w1_q,
        w1_blockscale=w1_blockscale,
        w1_alphas=(1 / w1_gs),
        a2_gscale=a2_gs,
        w2_fp4=w2_q,
        w2_blockscale=w2_blockscale,
        w2_alphas=(1 / w2_gs),
        ab_strides_13=ab_strides_13,
        ab_strides_2=ab_strides_2,
        c_strides_13=c_strides_13,
        c_strides_2=c_strides_2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        m=m,
        n=n,
        k=k,
        e=e,
        device=a.device,
    )
    # Reference check:
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
    ).to(torch.float32)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
    _, m_k = a_fp4.shape
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        a_global_scale,
        dtype=a.dtype,
        device=a.device,
        block_size=quant_blocksize,
    )

    w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
    w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

    for idx in range(0, e):
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_q[idx],
            w1_blockscale[idx],
            w1_gs[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=quant_blocksize,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_q[idx],
            w2_blockscale[idx],
            w2_gs[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=quant_blocksize,
        )

    torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)

    torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
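Because the module keeps the largest configuration behind the `__main__` guard, it can be invoked directly with Python for a quick smoke test, while pytest picks up the full m/n/k, expert-count, top-k and dtype sweep defined by the parametrize decorators above; as the module-level skip enforces, either way requires a GPU with compute capability 10.0 or above.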
sgl-kernel/CMakeLists.txt

...
@@ -210,6 +210,7 @@ set(SOURCES
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
    "csrc/gemm/fp8_gemm_kernel.cu"
    "csrc/gemm/int8_gemm_kernel.cu"
    "csrc/gemm/nvfp4_expert_quant.cu"
    "csrc/gemm/nvfp4_quant_entry.cu"
    "csrc/gemm/nvfp4_quant_kernels.cu"
    "csrc/gemm/nvfp4_scaled_mm_entry.cu"
...
@@ -222,6 +223,7 @@ set(SOURCES
    "csrc/moe/moe_align_kernel.cu"
    "csrc/moe/moe_fused_gate.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
    "csrc/moe/nvfp4_blockwise_moe.cu"
    "csrc/moe/fp8_blockwise_moe_kernel.cu"
    "csrc/moe/prepare_moe_input.cu"
    "csrc/moe/ep_moe_reorder_kernel.cu"
...
sgl-kernel/csrc/common_extension.cc

...
@@ -132,6 +132,20 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      " Tensor! output_scale, Tensor! input_scale) -> ()");
  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);

  // Compute NVFP4 experts quantization.
  m.def(
      "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
      "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
      "Tensor output_scale_offset_by_experts) -> ()");
  m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);

  m.def(
      "cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
      "Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
      "Tensor ab_strides, Tensor c_strides, Tensor problem_sizes,"
      " Tensor expert_offsets, Tensor sf_offsets) -> ()");
  m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);

  /*
   * From csrc/moe
   */
...
@@ -161,9 +175,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "expert_offsets, Tensor workspace) -> ()");
  m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);

  m.def(
-     "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
-     "input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
+     "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
+     " Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
+     "()");
  m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);

  m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
  m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);

  /*
   * From csrc/speculative
   */
...
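For orientation: ops declared through TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) become callable from Python under the torch.ops.sgl_kernel namespace, and the thin wrappers added in sgl-kernel/python/sgl_kernel/gemm.py and moe.py (whose bodies are not shown in this diff excerpt) presumably forward to these schemas. A hedged sketch of a raw call against the updated prepare_moe_input schema, with every tensor variable being illustrative:

import torch

# Assumes the sgl_kernel extension is loaded so the schemas above are registered.
torch.ops.sgl_kernel.prepare_moe_input(
    topk_ids, expert_offsets, blockscale_offsets,  # the Tensor? argument may also be None
    problem_sizes1, problem_sizes2,
    input_permutation, output_permutation,
    num_experts, n, k,
)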
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu (new file, 0 → 100644)

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>

template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

#define ELTS_PER_THREAD 8

constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
  // PTX instructions used here requires sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
#else
  return 0;
#endif
#endif
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
  // PTX instructions used here requires sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  return 0;
#endif
#endif
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads write one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
    // SF vector index (16 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 16.
    int factor = CVT_FP4_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    int64_t mTileStride = numKTiles * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
                       innerMIdx * innerMStride + innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  }
#endif
  return nullptr;
}

// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

// Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
  // reciprocal(SFScaleVal))
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
#else
  return 0;
#endif
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows,
    int32_t numCols,
    Type const* in,
    float const* SFScale,
    uint32_t* out,
    uint32_t* SFout,
    uint32_t* input_offset_by_experts,
    uint32_t* output_scale_offset_by_experts,
    int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      // Find index within the experts.
      int rowIdx_in_expert = 0;
      int expert_idx = 0;
      for (int i = 0; i < n_experts; i++) {
        if (rowIdx >= input_offset_by_experts[i] && rowIdx < input_offset_by_experts[i + 1]) {
          rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
          expert_idx = i;
          break;
        }
      }

      // Get the global scaling factor, which will be applied to the SF.
      // Note SFScale is the same as next GEMM's alpha, which is
      // (448.f / (Alpha_A / 6.f)).
      float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

      int factor = CVT_FP4_SF_VEC_SIZE * 4;
      // The actual output_scales dim is computed from the padded numCols.
      int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
      int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
      uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;

      auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
          rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

      out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}
template <typename T>
void quant_impl(
    void* output,
    void* output_scale,
    void* input,
    void* input_global_scale,
    void* input_offset_by_experts,
    void* output_scale_offset_by_experts,
    int m_topk,
    int k,
    int n_experts,
    cudaStream_t stream) {
  // TODO: this multiProcessorCount should be cached.
  int device;
  cudaGetDevice(&device);
  int multiProcessorCount;
  cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);

  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));

  cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
      m_topk,
      k,
      reinterpret_cast<T*>(input),
      reinterpret_cast<float*>(input_global_scale),
      reinterpret_cast<uint32_t*>(output),
      reinterpret_cast<uint32_t*>(output_scale),
      reinterpret_cast<uint32_t*>(input_offset_by_experts),
      reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
      n_experts);
}

/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
  CHECK_TH_CUDA(x, m);    \
  CHECK_CONTIGUOUS(x, m);

// constexpr auto FP8 = at::ScalarType::Float8_e4m3fn;
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
  CHECK_INPUT(output, "output must be a CUDA tensor");
  CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
  CHECK_INPUT(input, "input must be a CUDA tensor");
  CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
  CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");

  TORCH_CHECK(output.dim() == 2);
  TORCH_CHECK(output_scale.dim() == 2);
  TORCH_CHECK(input.dim() == 2);
  TORCH_CHECK(input_global_scale.dim() == 1);
  TORCH_CHECK(input_offset_by_experts.dim() == 1);
  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);

  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
  // output is uint8 (two nvfp4 values are packed into one uint8)
  // output_scale is int32 (four fp8 values are packed into one int32)
  TORCH_CHECK(output.scalar_type() == UINT8);
  TORCH_CHECK(output_scale.scalar_type() == INT);

  const int BLOCK_SIZE = 16;
  auto m_topk = input.size(0);
  auto k = input.size(1);
  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
  auto n_experts = input_global_scale.size(0);
  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
  TORCH_CHECK(output.size(0) == m_topk);
  TORCH_CHECK(output.size(1) == k / 2);
  int scales_k = k / BLOCK_SIZE;
  // 4 means the swizzle requirement by nvidia nvfp4.
  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 means 4 fp8 values are packed into one int32
  TORCH_CHECK(output_scale.size(1) * 4 == padded_k);

  auto in_dtype = input.dtype();
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
  if (in_dtype == at::ScalarType::Half) {
    quant_impl<half>(
        output.data_ptr(),
        output_scale.data_ptr(),
        input.data_ptr(),
        input_global_scale.data_ptr(),
        input_offset_by_experts.data_ptr(),
        output_scale_offset_by_experts.data_ptr(),
        m_topk,
        k,
        n_experts,
        stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    quant_impl<__nv_bfloat16>(
        output.data_ptr(),
        output_scale.data_ptr(),
        input.data_ptr(),
        input_global_scale.data_ptr(),
        input_offset_by_experts.data_ptr(),
        output_scale_offset_by_experts.data_ptr(),
        m_topk,
        k,
        n_experts,
        stream);
  } else {
    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
  }
}
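The scale-factor (SF) tensor written above uses the 128x4 swizzled layout NVFP4 expects: rows are grouped into tiles of 128 (a column-major 32 x 4 arrangement) and SF columns are interleaved in groups of 4. A small Python sketch of the same offset computation, matching the index comments in cvt_quant_to_fp4_get_sf_out_offset (illustration only, not part of the kernel; k_idx is the index of the 16-element SF vector along K):

def sf_swizzled_offset(m_idx: int, k_idx: int, num_cols: int, sf_vec_size: int = 16) -> int:
    # SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
    num_k_tiles = (num_cols + sf_vec_size * 4 - 1) // (sf_vec_size * 4)
    m_tile_idx, k_tile_idx = m_idx // (32 * 4), k_idx // 4
    outer_m_idx = m_idx % 32                 # column-major [32, 4] M tile
    inner_m_idx = (m_idx % (32 * 4)) // 32
    inner_k_idx = k_idx % 4
    return (((m_tile_idx * num_k_tiles + k_tile_idx) * 32 + outer_m_idx) * 4 + inner_m_idx) * 4 + inner_k_idx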
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu

...
@@ -18,6 +18,15 @@ limitations under the License.
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
void scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
#endif

void scaled_fp4_quant(
...
@@ -27,3 +36,17 @@ void scaled_fp4_quant(
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
}

void scaled_fp4_experts_quant(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return scaled_fp4_experts_quant_sm100a(
      output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu (new file, 0 → 100644)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cutlass/arch/arch.h>
#include <torch/all.h>
#include <cassert>
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/tensor_view_io.h"
using namespace cute;

template <
    typename ElementAB,
    typename ElementC,
    typename ElementSF,
    typename ElementAccumulator,
    typename LayoutSFA,
    typename LayoutSFB,
    typename ScaleConfig>
__global__ void __get_group_gemm_starts(
    ElementAB** a_offsets,
    ElementAB** b_offsets,
    ElementC** out_offsets,
    ElementSF** a_scales_offsets,
    ElementSF** b_scales_offsets,
    ElementAccumulator** alpha_offsets,
    LayoutSFA* layout_sfa_base_as_int,
    LayoutSFB* layout_sfb_base_as_int,
    ElementAB* a_base_as_int,
    ElementAB* b_base_as_int,
    ElementC* out_base_as_int,
    ElementSF* a_scales_base_as_int,
    ElementSF* b_scales_base_as_int,
    ElementAccumulator* alphas_base_as_int,
    const int32_t* expert_offsets,
    const int32_t* sf_offsets,
    const int32_t* problem_sizes_as_shapes,
    const int K,
    const int N) {
  int64_t expert_id = threadIdx.x;
  if (expert_id >= gridDim.x * blockDim.x) {
    return;
  }

  // Originally int32_t but upcasting to int64_t to avoid overflow
  // during offset calculations
  int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
  int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
  // size for block in block scale.
  int64_t group_size = 16;
  int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
  int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
  int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
  assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes");

  int64_t half_k = static_cast<int64_t>(k / 2);
  int64_t group_k = static_cast<int64_t>(k / group_size);
  // Shape of A as uint8/byte = [M, K // 2]
  // Shape of B as uint8/byte = [E, N, K // 2]
  a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
  b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
  // Shape of C = [M, N]
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  // Shape of a_scale = [sum(sf_sizes), K // group_size]
  a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
  assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
  // Shape of B scale = [E, N, K // group_size]
  b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
  assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
  // Shape of alpha = [E]
  alpha_offsets[expert_id] = alphas_base_as_int + expert_id;

  LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
  LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;

  *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(
      cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
  *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(
      cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE( \
ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), \
K, \
N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
    const torch::Tensor& a_starts,
    const torch::Tensor& b_starts,
    const torch::Tensor& out_starts,
    const torch::Tensor& a_scales_starts,
    const torch::Tensor& b_scales_starts,
    const torch::Tensor& alpha_starts,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    /*these are used for their base addresses*/
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& out_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& alphas,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& sf_offsets,
    torch::Tensor const& problem_sizes,
    int M,
    int N,
    int K) {
  int num_experts = (int)expert_offsets.size(0);
  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  TORCH_CHECK(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape");
  TORCH_CHECK(K / 2 == b_tensors.size(2), "b_tensors(dim = 2) and a_tensors(dim = 1) trailing dimension must match");

  if (false) {
  }
  //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
  // ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
      cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
      cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& a_blockscale,
    const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas,
    const torch::Tensor& ab_strides,
    const torch::Tensor& c_strides,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& sf_offsets,
    int M,
    int N,
    int K) {
  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
  using ElementType = cutlass::float_e2m1_t;
  using ElementSFType = cutlass::float_ue4m3_t;
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;

  using ElementC = OutType;
  using ElementD = ElementC;
  using ElementAccumulator = float;

  // Layout definitions
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = LayoutC;

  // Alignment constraints
  static constexpr int AlignmentA = 32;
  static constexpr int AlignmentB = 32;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  // Architecture definitions
  using ArchTag = cutlass::arch::Sm100;
  using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp;              // Epilogue Operator class tag
  using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;   // Mainloop Operator class tag
  using StageCountType = cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
  using ClusterShape = Shape<_1, _1, _1>;

  struct MMA1SMConfig {
    using MmaTileShape = Shape<_128, _128, _128>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100;  // Kernel to launch
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;           // Epilogue to launch
  };

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      EpilogueOperatorClass,
      typename MMA1SMConfig::MmaTileShape,
      ClusterShape,
      Shape<_128, _64>,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutC*,
      AlignmentC,
      ElementD,
      LayoutC*,
      AlignmentD,
      typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      MainloopOperatorClass,
      ElementA,
      LayoutA*,
      AlignmentA,
      ElementB,
      LayoutB*,
      AlignmentB,
      ElementAccumulator,
      typename MMA1SMConfig::MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      typename MMA1SMConfig::KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
  using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using Gemm = Gemm1SM;
  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;

  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
  using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;

  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
  int num_experts = static_cast<int>(expert_offsets.size(0));

  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);

  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
      a_ptrs,
      b_ptrs,
      out_ptrs,
      a_scales_ptrs,
      b_scales_ptrs,
      alpha_ptrs,
      layout_sfa,
      layout_sfb,
      a,
      b,
      output,
      a_blockscale,
      b_blockscales,
      alphas,
      expert_offsets,
      sf_offsets,
      problem_sizes,
      M,
      N,
      K);

  // Create an instance of the GEMM
  Gemm gemm_op;

  // Initialize problem_sizes_as_shapes correctly
  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());

  // Set the Scheduler info
  cutlass::KernelHardwareInfo hw_info;
  using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<
      typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
  scheduler.raster_order = RasterOrderOptions::AlongM;
  hw_info.device_id = a.get_device();
  static std::unordered_map<int, int> cached_sm_counts;
  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
    cached_sm_counts[hw_info.device_id] =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
  }
  hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);

  // Mainloop Arguments
  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementType**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(ab_strides.data_ptr()),
      static_cast<const ElementType**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(ab_strides.data_ptr()),
      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};

  // Epilogue Arguments
  typename GemmKernel::EpilogueArguments epilogue_args{
      {},  // epilogue.thread
      nullptr,
      static_cast<StrideC*>(c_strides.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(c_strides.data_ptr())};
  auto& fusion_args = epilogue_args.thread;
  fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
  fusion_args.dAlpha = {_0{}, _0{}, 1};

  // Gemm Arguments
  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      mainloop_args,
      epilogue_args,
      hw_info,
      scheduler};

  size_t workspace_size = Gemm::get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

  // Run the GEMM
  auto status = gemm_op.initialize(args, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& a_blockscale,
    const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas,
    const torch::Tensor& ab_strides,
    const torch::Tensor& c_strides,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
  constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

  // Input validation
  CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
  CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
  CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
  CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");

  TORCH_CHECK(
      a_blockscale.dim() == 2,
      "expected a_blockscale to be of shape [num_experts, rounded_m, k // group_size], observed rank: ",
      a_blockscale.dim())
  TORCH_CHECK(
      b_blockscales.dim() == 3,
      "expected b_blockscale to be of shape: [num_experts, n, k // group_size], observed rank: ",
      b_blockscales.dim())
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have the shape (num_experts, 3)");
  TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0),
      "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32.");

  int M = static_cast<int>(a.size(0));
  int N = static_cast<int>(b.size(1));
  int E = static_cast<int>(b.size(0));
  int K = static_cast<int>(2 * b.size(2));
  if (output.scalar_type() == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides, problem_sizes, expert_offsets,
        sf_offsets, M, N, K);
  } else {
    run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, ab_strides, c_strides, problem_sizes, expert_offsets,
        sf_offsets, M, N, K);
  }
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_fp4_group_mm kernel, sgl-kernel must "
      "be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
      "12.8 or above.");
#endif
}
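Pulling together the shape checks and offset-computation comments above, the host-side contract of cutlass_fp4_group_mm looks roughly like the summary below (a hedged reading of the code, written as Python-style shape annotations rather than authoritative API documentation):

# a               : [M_total, K // 2]          torch.uint8 (packed E2M1, two values per byte)
# b               : [E, N, K // 2]             torch.uint8
# a_blockscale    : [sum(sf_sizes), K // 16]   torch.float8_e4m3fn (rows padded per expert to 128)
# b_blockscales   : [E, N, K // 16]            torch.float8_e4m3fn
# alphas          : [E]                        torch.float32
# ab_strides, c_strides : [E]                  torch.int64
# problem_sizes   : [E, 3]                     torch.int32, one (m_e, N, K) row per expert
# expert_offsets, sf_offsets : [E]             torch.int32 (per-expert row starts into a / a_blockscale)
# output          : [M_total, N]               torch.half or torch.bfloat16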
sgl-kernel/csrc/moe/prepare_moe_input.cu

...
@@ -4,6 +4,8 @@
#include <iostream>

#include "cutlass/array.h"

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(
...
@@ -11,9 +13,9 @@ __global__ void compute_problem_sizes(
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
-   const int topk_length,
-   const int n,
-   const int k) {
+   const int64_t topk_length,
+   const int64_t n,
+   const int64_t k) {
  int expert_id = blockIdx.x;
  int occurrences = 0;
...
@@ -26,11 +28,11 @@ __global__ void compute_problem_sizes(
  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
-   problem_sizes1[expert_id * 3 + 1] = 2 * n;
-   problem_sizes1[expert_id * 3 + 2] = k;
+   problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
+   problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3] = final_occurrences;
-   problem_sizes2[expert_id * 3 + 1] = k;
-   problem_sizes2[expert_id * 3 + 2] = n;
+   problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
+   problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
  }
}
...
@@ -38,7 +40,7 @@ __global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
-   const int num_experts) {
+   const int64_t num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
...
@@ -48,13 +50,34 @@ __global__ void compute_expert_offsets(
  }
}

__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* blockscale_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t tot_offset = 0;
  int32_t tot_rounded_offset = 0;
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    int num_tokens = problem_sizes1[i * 3];
    int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128;
    tot_offset += num_tokens;
    tot_rounded_offset += rounded_num_tokens;
    expert_offsets[i + 1] = tot_offset;
    blockscale_offsets[i + 1] = tot_rounded_offset;
  }
}
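The new compute_expert_blockscale_offsets kernel keeps two running sums: exact token offsets for the activations and 128-row-aligned offsets for the block scales, since the swizzled SF layout is padded to 128 rows per expert. The equivalent in plain Python, for illustration only:

def blockscale_offsets(tokens_per_expert):
    expert_offsets, sf_offsets = [0], [0]
    for m in tokens_per_expert:
        expert_offsets.append(expert_offsets[-1] + m)
        sf_offsets.append(sf_offsets[-1] + (m + 127) // 128 * 128)  # round rows up to 128
    return expert_offsets, sf_offsets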
__global__ void compute_arg_sorts(
-   const int* __restrict__ topk_ids,
+   const int32_t* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
-   const int topk_length,
-   const int topk) {
+   const int64_t topk_length,
+   const int64_t topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
...
@@ -69,6 +92,7 @@ __global__ void compute_arg_sorts(
void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
...
@@ -80,8 +104,10 @@ void get_moe_prepare_input_caller(
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

-  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
-  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
+  uint32_t num_threads = static_cast<uint32_t>(min(THREADS_PER_EXPERT, topk_ids.numel()));
+  uint32_t num_blocks = static_cast<uint32_t>(num_experts);
+  compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
...
@@ -89,12 +115,21 @@ void get_moe_prepare_input_caller(
      topk_ids.numel(),
      n,
      k);
-  compute_expert_offsets<<<1, 1, 0, stream>>>(
-      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
-      static_cast<int32_t*>(expert_offsets.data_ptr()),
-      static_cast<int32_t*>(atomic_buffer.data_ptr()),
-      num_experts);
-  compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
+  if (blockscale_offsets.has_value()) {
+    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
+        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(expert_offsets.data_ptr()),
+        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()),
+        num_experts);
+  } else {
+    compute_expert_offsets<<<1, 1, 0, stream>>>(
+        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(expert_offsets.data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()),
+        num_experts);
+  }
+  compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
...
@@ -106,6 +141,7 @@ void get_moe_prepare_input_caller(
void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
...
...
@@ -117,6 +153,7 @@ void prepare_moe_input(
  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      blockscale_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
...
...
@@ -126,3 +163,92 @@ void prepare_moe_input(
      k);
  return;
}
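A hedged sketch of how a caller might size the buffers that prepare_moe_input fills, inferred from the indexing above (expert_offsets and blockscale_offsets are written at index E + 1, problem sizes are read as i * 3, and the permutations carry one slot per (token, expert) pair); the Python wrapper and tests added in this commit are the authoritative reference.

import torch

# Hypothetical helper, for illustration only.
def alloc_prepare_moe_input_buffers(topk_ids: torch.Tensor, num_experts: int):
    i32 = dict(dtype=torch.int32, device=topk_ids.device)
    expert_offsets = torch.empty(num_experts + 1, **i32)
    blockscale_offsets = torch.empty(num_experts + 1, **i32)  # FP4 path only
    problem_sizes1 = torch.empty(num_experts, 3, **i32)       # (M, N, K) per expert, first grouped GEMM
    problem_sizes2 = torch.empty(num_experts, 3, **i32)       # (M, N, K) per expert, second grouped GEMM
    input_permutation = torch.empty(topk_ids.numel(), **i32)
    output_permutation = torch.empty(topk_ids.numel(), **i32)
    return (expert_offsets, blockscale_offsets, problem_sizes1,
            problem_sizes2, input_permutation, output_permutation)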
template <typename T>
__global__ void shuffleRowsKernel(
    const T* input, const int32_t* dst2src_map, T* output,
    int64_t num_src_rows, int64_t num_dst_rows, int64_t num_cols) {
  int64_t dest_row_idx = blockIdx.x;
  int64_t const source_row_idx = dst2src_map[dest_row_idx];

  if (blockIdx.x < num_dst_rows) {
    // Load 128-bits per thread
    constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

    // Duplicate and permute rows
    auto const* source_row_ptr = reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
    auto* dest_row_ptr = reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);

    auto const start_offset = threadIdx.x;
    auto const stride = blockDim.x;
    auto const num_elems_in_col = num_cols / ELEM_PER_THREAD;

    for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
    }
  }
}
#define DECLARE_SHUFFLE_ROWS(T) \
__global__ void shuffleRowsKernel( \
const T* input, \
const int32_t* dst2src_map, \
T* output, \
int64_t num_src_rows, \
int64_t num_dest_rows, \
int64_t num_cols);
DECLARE_SHUFFLE_ROWS(float);
DECLARE_SHUFFLE_ROWS(half);
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
DECLARE_SHUFFLE_ROWS(uint8_t);
#define SHUFFLE_ROWS(T) \
shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>( \
reinterpret_cast<const T*>(input), \
static_cast<const int32_t*>(dst2src_map.data_ptr()), \
reinterpret_cast<T*>(output), \
num_src_rows, \
num_dst_rows, \
num_cols)
#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
case T: \
SHUFFLE_ROWS(CUDA_T); \
break;
void shuffle_rows_caller(
    const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "Input and output tensors must have the same data type");

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  uint32_t blocks = static_cast<uint32_t>(output_tensor.size(0));
  uint32_t threads = 256;
  int64_t num_dst_rows = output_tensor.size(0);
  int64_t num_src_rows = input_tensor.size(0);
  int64_t num_cols = input_tensor.size(1);
  const void* input = input_tensor.data_ptr();
  void* output = output_tensor.data_ptr();

  switch (input_tensor.scalar_type()) {
    DTYPE_DISPATCH_CASE(torch::kFloat16, half);
    DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
    DTYPE_DISPATCH_CASE(torch::kFloat32, float);
    DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
    DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
    default:
      TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
  }
  return;
}
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
  return;
}
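Functionally, shuffle_rows is a row gather performed with 128-bit vectorized copies: output row i is input row dst2src_map[i]. A rough PyTorch reference, useful for testing, assuming dst2src_map holds valid source-row indices:

import torch

def shuffle_rows_reference(input_tensor: torch.Tensor, dst2src_map: torch.Tensor) -> torch.Tensor:
    # out[i, :] = input_tensor[dst2src_map[i], :]
    return input_tensor.index_select(0, dst2src_map.long())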
sgl-kernel/include/sgl_kernel_ops.h
View file @
eb38c7d1
...
...
@@ -232,6 +232,7 @@ void fp8_blockwise_scaled_grouped_mm(
void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
...
...
@@ -251,6 +252,29 @@ void ep_moe_pre_reorder(
    int64_t topk,
    bool use_per_token_if_dynamic);

void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);

void cutlass_fp4_group_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& a_blockscale,
    const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas,
    const torch::Tensor& ab_strides,
    const torch::Tensor& c_strides,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& sf_offsets);

void scaled_fp4_experts_quant(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
/*
* From csrc/speculative
*/
...
...
sgl-kernel/python/sgl_kernel/__init__.py
View file @
eb38c7d1
...
...
@@ -38,14 +38,17 @@ from sgl_kernel.gemm import (
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
    scaled_fp4_experts_quant,
    scaled_fp4_quant,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_fp8,
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
    shuffle_rows,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
    cutlass_fp4_group_mm,
    ep_moe_pre_reorder,
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
...
...
sgl-kernel/python/sgl_kernel/gemm.py
View file @
eb38c7d1
...
...
@@ -241,3 +241,80 @@ def qserve_w4a8_per_group_gemm(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats


def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    output_tensor = torch.empty(
        output_tensor_shape,
        device=input_tensor.device,
        dtype=input_tensor.dtype,
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
    return output_tensor
def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
    expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP4 and return the quantized tensor and scales,
    for packed MoE inputs.
    Args:
        input: The input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor.
        expert_offsets: The expert offsets tensor
        blockscale_offsets: The blockscale offsets tensor
        expert_map: The expert map tensor
    Outputs:
        output: The quantized tensor in FP4
        output_scales: The blockscale tensor in FP8-E4M3
    """
    assert (
        input_tensor.ndim == 2
    ), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    if expert_map is not None:
        (m, k) = input_tensor.shape
        output_tensor_shape = (m * topk, k)
        input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)

    m_numtopk, k = input_tensor.shape

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    import os

    MAX_TOKENS_PER_EXPERT = int(os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536))
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
    )
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8)
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales
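For orientation, a small shape sketch of what the code above allocates (illustrative sizes; the scale buffer is created as int32 and later viewed as float8_e4m3fn, i.e. four FP8 blockscales per int32):

m, topk, k = 32, 4, 4096                      # example routing problem
max_tokens_per_expert = 65536                 # MODELOPT_MAX_TOKENS_PER_EXPERT default
scales_k = k // 16                            # one FP8-E4M3 blockscale per 16 FP4 values
padded_k = (scales_k + (4 - 1)) // 4          # scales packed four-per-int32
fp4_shape = (m * topk, k // 2)                # two E2M1 values per uint8 -> (128, 2048)
scale_shape_int32 = (max_tokens_per_expert * topk, padded_k)
print(fp4_shape, scale_shape_int32)           # (128, 2048) (262144, 64)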
sgl-kernel/python/sgl_kernel/moe.py
View file @
eb38c7d1
from typing import Optional

import torch
...
...
@@ -138,10 +140,12 @@ def prepare_moe_input(
    num_experts,
    n,
    k,
    blockscale_offsets: Optional[torch.Tensor] = None,
):
    torch.ops.sgl_kernel.prepare_moe_input.default(
        topk_ids,
        expert_offsets,
        blockscale_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
...
...
@@ -150,3 +154,54 @@ def prepare_moe_input(
        n,
        k,
    )
def cutlass_fp4_group_mm(
    a_fp4,
    b_fp4,
    a_blockscale,
    b_blockscale,
    alphas,
    ab_strides,
    c_strides,
    problem_sizes,
    expert_offsets,
    blockscale_offsets,
    out_dtype,
    device,
):
    """
    An FP4 blockscaled grouped GEMM that takes in a_tensors and b_tensors and
    runs the GEMMs for each expert based on the specified problem sizes.
    This is used as the MoE GEMM during the NVFP4-quantized FusedMoE forward.
    - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors, i.e. the quantized
      input and expert weights.
    - a_/b_scales: the blockscales in FP8-E4M3 precision.
    - ab_strides/c_strides: strides for the a/b and output tensors between rows.
    - expert_offsets/sf_offsets: indices that mark at which token index each
      expert begins its computation. The number of tokens computed with
      expert E is expert_offsets[E + 1] - expert_offsets[E], and the sf size
      per expert is sf_offsets[E + 1] - sf_offsets[E].
    - problem_sizes: MxNxK sizes of each expert's multiplication in the two
      grouped MMs used in the fused MoE operation.
    """
    m_topk = a_fp4.shape[0]
    n = b_fp4.shape[1]
    c_shape = (m_topk, n)
    c = torch.empty(c_shape, device=device, dtype=out_dtype)
    torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
        c,
        a_fp4,
        b_fp4,
        a_blockscale,
        b_blockscale,
        alphas,
        ab_strides,
        c_strides,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
    )
    return c.to(dtype=out_dtype)
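As the docstring notes, the grouped GEMM is driven entirely by the offset and problem-size tensors. A hedged sketch of recovering the per-expert sub-problems on the host, for illustration only:

import torch

def per_expert_problems(expert_offsets: torch.Tensor, problem_sizes: torch.Tensor):
    # expert_offsets: (num_experts + 1,) int32; problem_sizes: (num_experts, 3) int32 as (M, N, K)
    tokens_per_expert = (expert_offsets[1:] - expert_offsets[:-1]).tolist()
    sub_gemms = [tuple(mnk) for mnk in problem_sizes.tolist() if mnk[0] > 0]
    return tokens_per_expert, sub_gemms  # experts with zero tokens contribute no work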