Unverified commit 5c34b4f1, authored Aug 29, 2025 by Kaixi Hou, committed by GitHub on Aug 29, 2025

[NVIDIA] [2/N] Optimize `silu_and_mul_scaled_fp4_grouped_quant` perf (#9556)

parent ff9b5618
Showing 7 changed files with 297 additions and 61 deletions (+297 −61)
  benchmark/kernels/quantization/bench_fp4_quant.py   +133 −0
  sgl-kernel/csrc/common_extension.cc                  +1 −2
  sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu           +134 −19
  sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu            +5 −7
  sgl-kernel/include/sgl_kernel_ops.h                  +2 −3
  sgl-kernel/python/sgl_kernel/gemm.py                 +5 −21
  sgl-kernel/tests/test_fp4_quantize.py                +17 −9
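For orientation, a minimal sketch (not part of the commit) of the two paths the benchmark below compares, using the mask-taking wrapper signatures introduced by this diff: the unfused path materializes the SiLU-and-mul intermediate and then quantizes it, while the fused kernel does both in one pass. Requires an NVFP4-capable GPU; the sizes are illustrative.

import torch
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul

E, M, K = 6, 6144, 2048  # experts, rows per expert, hidden size (illustrative)
x = torch.randn(E, M, K, device="cuda", dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device="cuda")
masks = torch.full((E,), M, dtype=torch.int32, device="cuda")  # rows to quantize per expert

# Unfused: separate activation kernel, then grouped FP4 quantization.
out_ref, scales_ref = scaled_fp4_grouped_quant(silu_and_mul(x), glb_scales, masks)
# Fused: one kernel applies silu_and_mul and quantizes in a single pass.
out, scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)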
benchmark/kernels/quantization/bench_fp4_quant.py (new file, mode 0 → 100644)
import argparse
import itertools

import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper


def _test_accuracy_once(E, M, K, input_dtype, device):
    x = torch.randn(E, M, K, device=device, dtype=input_dtype)
    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
    masks = torch.full((E,), M, dtype=torch.int32, device=device)
    out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
    out1, blk_scales1 = scaled_fp4_grouped_quant(
        silu_and_mul(x),
        glb_scales,
        masks,
    )
    torch.testing.assert_close(out, out1)
    torch.testing.assert_close(blk_scales, blk_scales1)
    print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")


NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "K"],
        x_vals=list(itertools.product(Ms, Ks)),
        x_log=False,
        line_arg="provider",
        line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
        line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
        ylabel="ms",
        plot_name="fp4 quant",
        args={},
    )
)
def benchmark(M, K, provider):
    E = 6
    device = "cuda"
    x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
    masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
    fp8_out = torch.empty(
        (x.shape[0], x.shape[1], x.shape[2] // 2),
        device=x.device,
        dtype=torch.float8_e4m3fn,
    )
    scale_block_size = 128
    fp8_scales = torch.empty(
        (x.shape[0], x.shape[1], x.shape[2] // 2 // scale_block_size),
        device=x.device,
        dtype=torch.float32,
    )
    quantiles = [0.5, 0.2, 0.8]
    if provider == "triton_fp8":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: silu_and_mul_masked_post_quant_fwd(
                x,
                fp8_out,
                fp8_scales,
                scale_block_size,
                masks,
                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            ),
            quantiles=quantiles,
        )
    if provider == "cuda_unfused_fp4":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: scaled_fp4_grouped_quant(
                silu_and_mul(x),
                glb_scales,
                masks,
            ),
            quantiles=quantiles,
        )
    if provider == "cuda_fused_fp4":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: silu_and_mul_scaled_fp4_grouped_quant(
                x,
                glb_scales,
                masks,
            ),
            quantiles=quantiles,
        )
    return ms, min_ms, max_ms


def test_accuracy():
    E = 6
    N_RANKS = 48
    Ms = [128, 256, 512, 1024]
    Ks = [2048, 4096, 7168]
    input_dtype = torch.bfloat16
    for M in Ms:
        for K in Ks:
            _test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./bench_fp4_quant_res",
        help="Path to save fp4 quant benchmark results",
    )
    args = parser.parse_args()
    test_accuracy()
    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
sgl-kernel/csrc/common_extension.cc

@@ -159,8 +159,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
-      "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
-      "Tensor output_scale_offset_by_experts, Tensor mask) -> ()");
+      "Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()");
   m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
   m.def(
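The re-registered schema above replaces the two cumulative offset tensors with a per-expert mask and a use_silu_and_mul flag, so one op now serves both the plain and the fused path. A minimal sketch of a call against that schema, mirroring what sgl-kernel/python/sgl_kernel/gemm.py does later in this diff; the padded scale width is derived from the kernel's padding math in this commit and the sizes are illustrative (requires an NVFP4-capable GPU).

import torch

l, m, k = 2, 512, 2048                   # experts, rows per expert, post-activation feature dim
padded_m = (m + 127) // 128 * 128        # scale rows padded to a multiple of 128
padded_k_int32 = (k + 63) // 64          # one int32 packs 4 fp8 scales, one scale per 16 values

x = torch.randn(l * m, 2 * k, device="cuda", dtype=torch.bfloat16)     # silu_and_mul halves the last dim
output = torch.empty(l * m, k // 2, device="cuda", dtype=torch.uint8)  # two nvfp4 values per byte
output_scales = torch.empty(l * padded_m, padded_k_int32, device="cuda", dtype=torch.int32)
input_global_scale = torch.ones(l, device="cuda", dtype=torch.float32)
mask = torch.full((l,), m, device="cuda", dtype=torch.int32)           # rows to quantize per expert

torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
    output, output_scales, x, input_global_scale, mask, use_silu_and_mul=True
)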
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu

@@ -347,7 +347,7 @@ cvt_fp16_to_fp4(
       }
     }
-    // Eerly exit when using masks.
+    // Early exit when using masks.
     if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
       continue;
     }
@@ -383,6 +383,107 @@ cvt_fp16_to_fp4(
 #endif
 }
 
+// Use UE4M3 by default.
+template <class Type, bool UE8M0_SF = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
+#else
+cvt_fp16_to_fp4_expert(
+#endif
+    int32_t numRows,
+    int32_t numCols,
+    Type const* in,
+    float const* SFScale,
+    uint32_t* out,
+    uint32_t* SFout,
+    int32_t* mask,
+    bool use_silu_and_mul,
+    int n_experts) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+
+  // Input tensor row/col loops.
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = (gridDim.x * blockDim.x) / n_experts;
+  int remainder = (gridDim.x * blockDim.x) % n_experts;
+  int expert_idx;
+  int tid_in_expert;
+  int actual_stride;
+  if (remainder > 0) {
+    int bound = remainder * (stride + 1);
+    if (tid < bound) {
+      expert_idx = tid / (stride + 1);
+      tid_in_expert = tid % (stride + 1);
+      actual_stride = stride + 1;
+    } else {
+      expert_idx = remainder + (tid - bound) / stride;
+      tid_in_expert = (tid - bound) % stride;
+      actual_stride = stride;
+    }
+  } else {
+    expert_idx = tid / stride;
+    tid_in_expert = tid % stride;
+    actual_stride = stride;
+  }
+
+  int m = numRows / n_experts;
+  int padded_m = (m + (128 - 1)) / 128 * 128;
+  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+  // TODO(kaixih@nvidia): For now, we assume mask is used together with
+  // silu_and_mul. Maybe we want a more general behavior of mask later. In the
+  // silu case, the input last dim doubles.
+  bool use_mask = mask != nullptr;
+  int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;
+
+  // Each global thread processes one element
+  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow;
+       globalIdx < (expert_idx + 1) * m * colsPerRow;
+       globalIdx += actual_stride) {
+    // Calculate which row and column this global thread should process
+    int rowIdx = globalIdx / colsPerRow;
+    int colIdx = globalIdx % colsPerRow;
+    // Find index within the experts
+    int rowIdx_in_expert = rowIdx - expert_idx * m;
+    // Early exit when using masks.
+    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
+      break;
+    }
+
+    int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    if (use_silu_and_mul) {
+      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
+      silu_and_mul(in_vec, in_vec_mul);
+    }
+
+    // Get the output tensor offset.
+    // Same as inOffset because 8 elements are packed into one uint32_t.
+    int64_t outOffset = rowIdx * colsPerRow + colIdx;
+    auto& out_pos = out[outOffset];
+
+    // Get the global scaling factor, which will be applied to the SF.
+    // Note SFScale is the same as next GEMM's alpha, which is
+    // (448.f / (Alpha_A / 6.f)).
+    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    // The actual output_scales dim is computed from the padded numCols.
+    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+    uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout;
+
+    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
+        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+  }
+#endif
+}
+
 // Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
 template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
 __global__ void
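The new kernel first carves the flat thread grid into per-expert slices: each expert gets gridDim.x * blockDim.x / n_experts threads, and when the division has a remainder the first `remainder` experts get one extra thread. A minimal Python sketch of that index math (illustrative counts, not part of the commit):

def partition(tid: int, total_threads: int, n_experts: int):
    """Map a flat thread id to (expert_idx, tid_in_expert, actual_stride),
    mirroring the index math in cvt_fp16_to_fp4_expert."""
    stride, remainder = divmod(total_threads, n_experts)
    if remainder > 0:
        bound = remainder * (stride + 1)
        if tid < bound:
            return tid // (stride + 1), tid % (stride + 1), stride + 1
        return remainder + (tid - bound) // stride, (tid - bound) % stride, stride
    return tid // stride, tid % stride, stride


# Every thread lands in exactly one expert, and threads within an expert are
# numbered 0..actual_stride-1, so the strided loop over that expert's
# m * colsPerRow elements touches each element exactly once.
total_threads, n_experts = 10, 3
assert sorted(partition(t, total_threads, n_experts) for t in range(total_threads)) == [
    (0, 0, 4), (0, 1, 4), (0, 2, 4), (0, 3, 4),
    (1, 0, 3), (1, 1, 3), (1, 2, 3),
    (2, 0, 3), (2, 1, 3), (2, 2, 3),
]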
@@ -499,6 +600,7 @@ void quant_impl(
     void* input_offset_by_experts,
     void* output_scale_offset_by_experts,
     void* mask,
+    bool use_silu_and_mul,
     int m_topk,
     int k,
     int n_experts,

@@ -522,6 +624,22 @@ void quant_impl(
     block.x = (block.x + 1) / 2;
   }
 
+  // TODO(kaixih@nvidia): Should relax this to allow any grid size.
+  if (mask != nullptr) {
+    grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
+    cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
+        m_topk,
+        k,
+        reinterpret_cast<T*>(input),
+        reinterpret_cast<float*>(input_global_scale),
+        reinterpret_cast<uint32_t*>(output),
+        reinterpret_cast<uint32_t*>(output_scale),
+        reinterpret_cast<int32_t*>(mask),
+        use_silu_and_mul,
+        n_experts);
+    return;
+  }
+
   int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
   if (blockRepeat > 1) {
     size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
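The dispatch above rounds grid.x up to the next multiple of n_experts, so the kernel's per-expert thread count (gridDim.x * blockDim.x / n_experts) divides evenly and the remainder branch stays cold; the same ceil-to-multiple idiom also produces padded_m inside the kernel. A tiny worked example (values illustrative, not from the commit):

def round_up(x: int, multiple: int) -> int:
    # Same idiom as grid.x = (grid.x + n_experts - 1) / n_experts * n_experts
    return (x + multiple - 1) // multiple * multiple

assert round_up(130, 6) == 132   # 130 blocks, 6 experts -> 22 blocks per expert
assert round_up(132, 6) == 132   # already a multiple: unchanged
assert round_up(100, 128) == 128 # padded_m for m = 100 rows per expert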
@@ -652,6 +770,7 @@ void scaled_fp4_experts_quant_sm100a(
       input_offset_by_experts.data_ptr(),
       output_scale_offset_by_experts.data_ptr(),
       nullptr,  // mask
+      false,    // use_silu_and_mul
       m_topk,
       k,
       n_experts,

@@ -665,6 +784,7 @@ void scaled_fp4_experts_quant_sm100a(
       input_offset_by_experts.data_ptr(),
       output_scale_offset_by_experts.data_ptr(),
       nullptr,  // mask
+      false,    // use_silu_and_mul
       m_topk,
       k,
       n_experts,

@@ -679,28 +799,21 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
     torch::Tensor& output_scale,
     torch::Tensor const& input,
     torch::Tensor const& input_global_scale,
-    torch::Tensor const& input_offset_by_experts,
-    torch::Tensor const& output_scale_offset_by_experts,
-    torch::Tensor const& mask) {
+    torch::Tensor const& mask,
+    bool use_silu_and_mul) {
   CHECK_INPUT(output, "output must be a CUDA tensor");
   CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
   CHECK_INPUT(input, "input must be a CUDA tensor");
   CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
-  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
-  CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
   CHECK_INPUT(mask, "mask must be a CUDA tensor");
 
   TORCH_CHECK(output.dim() == 2);
   TORCH_CHECK(output_scale.dim() == 2);
   TORCH_CHECK(input.dim() == 2);
   TORCH_CHECK(input_global_scale.dim() == 1);
-  TORCH_CHECK(input_offset_by_experts.dim() == 1);
-  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
   TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
   TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
-  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
-  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
   TORCH_CHECK(mask.scalar_type() == INT);
   // output is uint8 (two nvfp4 values are packed into one uint8)
   // output_scale is int32 (four fp8 values are packed into one int32)

@@ -710,12 +823,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
   const int BLOCK_SIZE = 16;
   auto m_topk = input.size(0);
   auto k_by_2 = input.size(1);
-  TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
-  auto k = k_by_2 / 2;
-  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
+  auto k = k_by_2;
+  if (use_silu_and_mul) {
+    TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
+    k = k_by_2 / 2;
+  }
   auto n_experts = input_global_scale.size(0);
-  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
-  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
   TORCH_CHECK(mask.size(0) == n_experts);
   TORCH_CHECK(output.size(0) == m_topk);
   TORCH_CHECK(output.size(1) == k / 2);

@@ -734,9 +847,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
       output_scale.data_ptr(),
       input.data_ptr(),
       input_global_scale.data_ptr(),
-      input_offset_by_experts.data_ptr(),
-      output_scale_offset_by_experts.data_ptr(),
+      nullptr,  // input_offset_by_experts
+      nullptr,  // output_scale_offset_by_experts
       mask.data_ptr(),
+      use_silu_and_mul,
       m_topk,
       k,
       n_experts,

@@ -747,9 +861,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
       output_scale.data_ptr(),
       input.data_ptr(),
       input_global_scale.data_ptr(),
-      input_offset_by_experts.data_ptr(),
-      output_scale_offset_by_experts.data_ptr(),
+      nullptr,  // input_offset_by_experts
+      nullptr,  // output_scale_offset_by_experts
      mask.data_ptr(),
+      use_silu_and_mul,
       m_topk,
       k,
       n_experts,
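With the reworked checks above, one entry point covers both flavors: when use_silu_and_mul is set the input's last dimension is 2*k and is halved by the activation, otherwise it is taken as k directly; either way each output row holds k/2 packed bytes. A minimal shape sketch in Python (numbers illustrative, not part of the commit):

def expected_shapes(m_topk: int, input_last_dim: int, use_silu_and_mul: bool):
    """Mirror the k / k_by_2 bookkeeping in silu_and_mul_scaled_fp4_experts_quant_sm100a."""
    k = input_last_dim // 2 if use_silu_and_mul else input_last_dim
    return {
        "input": (m_topk, input_last_dim),   # bf16/fp16
        "output": (m_topk, k // 2),          # uint8, two nvfp4 values per byte
    }

assert expected_shapes(1024, 4096, use_silu_and_mul=True)["output"] == (1024, 1024)
assert expected_shapes(1024, 2048, use_silu_and_mul=False)["output"] == (1024, 1024)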
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu

@@ -32,9 +32,8 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
     torch::Tensor& output_scale,
     torch::Tensor const& input,
     torch::Tensor const& input_global_scale,
-    torch::Tensor const& input_offset_by_experts,
-    torch::Tensor const& output_scale_offset_by_experts,
-    torch::Tensor const& mask);
+    torch::Tensor const& mask,
+    bool use_silu_and_mul);
 #endif

@@ -65,12 +64,11 @@ void silu_and_mul_scaled_fp4_experts_quant(
     torch::Tensor& output_scale,
     torch::Tensor const& input,
     torch::Tensor const& input_global_scale,
-    torch::Tensor const& input_offset_by_experts,
-    torch::Tensor const& output_scale_offset_by_experts,
-    torch::Tensor const& mask) {
+    torch::Tensor const& mask,
+    bool use_silu_and_mul) {
 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
   return silu_and_mul_scaled_fp4_experts_quant_sm100a(
-      output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask);
+      output, output_scale, input, input_global_scale, mask, use_silu_and_mul);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
 }
sgl-kernel/include/sgl_kernel_ops.h

@@ -394,9 +394,8 @@ void silu_and_mul_scaled_fp4_experts_quant(
     torch::Tensor& output_scale,
     torch::Tensor const& input,
     torch::Tensor const& input_global_scale,
-    torch::Tensor const& input_offset_by_experts,
-    torch::Tensor const& output_scale_offset_by_experts,
-    torch::Tensor const& mask);
+    torch::Tensor const& mask,
+    bool use_silu_and_mul);
 
 /*
  * From csrc/moe/cutlass_moe/w4a8
  */
sgl-kernel/python/sgl_kernel/gemm.py

@@ -298,6 +298,7 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
 def scaled_fp4_grouped_quant(
     input_tensor: torch.Tensor,
     input_global_scale: torch.Tensor,
+    mask: torch.Tensor,
 ):
     """
     Quantize input tensor to FP4 and return quantized tensor and scale, for

@@ -331,22 +332,14 @@ def scaled_fp4_grouped_quant(
     output_scales = torch.empty(
         l, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
-    input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
-    output_offsets = torch.arange(
-        0, (l + 1) * padded_m, step=padded_m, dtype=torch.int, device=device,
-    )
-    torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
+    torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
         output.view(l * m, k // 2),
         output_scales.view(l * padded_m, padded_k_int32),
         input_tensor.view(l * m, k),
         input_global_scale,
-        input_offsets,
-        output_offsets,
+        mask,
+        use_silu_and_mul=False,
     )
     # The physical layout of the output is (l, m, k // 2), but we want to return a
     # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.

@@ -400,23 +393,14 @@ def silu_and_mul_scaled_fp4_grouped_quant(
     output_scales = torch.empty(
         l, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
-    input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
-    output_offsets = torch.arange(
-        0, (l + 1) * padded_m, step=padded_m, dtype=torch.int, device=device,
-    )
     torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
         output.view(l * m, k // 2),
         output_scales.view(l * padded_m, padded_k_int32),
         input_tensor.view(l * m, k_by_2),
         input_global_scale,
-        input_offsets,
-        output_offsets,
         mask,
+        use_silu_and_mul=True,
     )
     # The physical layout of the output is (l, m, k // 2), but we want to return a
     # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
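The wrappers' closing comment notes that the kernel writes a physical (l, m, k // 2) buffer which is returned with a logical (m, k // 2, l) layout for the flashinfer masked group GEMM. One way to picture that relationship is a plain permute, which reorders strides without copying storage; a minimal sketch independent of the sgl_kernel internals:

import torch

l, m, k = 2, 512, 2048
out = torch.empty(l, m, k // 2, dtype=torch.uint8)  # physical layout written by the kernel
logical = out.permute(1, 2, 0)                      # logical (m, k // 2, l), no copy
assert logical.shape == (m, k // 2, l)
assert logical.data_ptr() == out.data_ptr()         # same storage, only strides change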
sgl-kernel/tests/test_fp4_quantize.py

@@ -174,17 +174,22 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
 @pytest.mark.skipif(
     skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
 )
-def test_quantize_to_fp4_grouped():
+@pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)])
+def test_quantize_to_fp4_grouped(shape):
     torch.manual_seed(42)
     torch.set_default_device("cuda:0")
-    l, m, k = 2, 512, 2048
+    l, m, k = shape
     x = torch.randn((l, m, k), dtype=torch.bfloat16)
+    max_m = m // 2
+    assert max_m <= m
+    mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
     tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
     x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
     output, output_scales = scaled_fp4_grouped_quant(
         x,
         x_sf_global,
+        mask,
     )
     # output in logical (m, k, l), but its physical layout is (l, m, k).
     # So permute first to (l, m, k).

@@ -195,23 +200,25 @@ def test_quantize_to_fp4_grouped():
     output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
     for i in range(l):
         a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
-        torch.testing.assert_close(a_fp4, output[i])
-        torch.testing.assert_close(
-            a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float)
-        )
+        torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]])
+        # Recover swizzled scales to linear layout and drop padded values, so
+        # no extra checks on padding are needed.
+        scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)
+        scale_ans = recover_swizzled_scales(output_scales[i], m, k)
+        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
 
 
 @pytest.mark.skipif(
     skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
 )
-@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)])
-def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
+@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)])
+def test_silu_and_mul_quantize_to_fp4_grouped(shape):
     torch.manual_seed(42)
     torch.set_default_device("cuda:0")
     l, m, k = shape
     x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
-    max_m = 8
+    max_m = m // 2
     assert max_m <= m
     mask = torch.randint(1, max_m, (l,), dtype=torch.int32)

@@ -221,6 +228,7 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
     ref_output, ref_output_scales = scaled_fp4_grouped_quant(
         ref_y,
         y_sf_global,
+        mask,
     )
     output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
         x,