change / sglang · Commits · e5638573

Unverified commit e5638573, authored Aug 22, 2025 by Kaixi Hou, committed by GitHub on Aug 22, 2025

[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm (#9200)

Parent: f556ac8b

Showing 7 changed files with 420 additions and 13 deletions (+420 -13)
sgl-kernel/csrc/common_extension.cc (+6 -0)
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu (+160 -12)
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu (+24 -0)
sgl-kernel/include/sgl_kernel_ops.h (+8 -0)
sgl-kernel/python/sgl_kernel/__init__.py (+2 -0)
sgl-kernel/python/sgl_kernel/gemm.py (+136 -0)
sgl-kernel/tests/test_fp4_quantize.py (+84 -1)
sgl-kernel/csrc/common_extension.cc

@@ -157,6 +157,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor output_scale_offset_by_experts) -> ()");
   m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
+  m.def(
+      "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
+      "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
+      "Tensor output_scale_offset_by_experts, Tensor mask) -> ()");
+  m.impl(
+      "silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
   m.def(
       "cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
       "Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu

@@ -239,6 +239,33 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 #endif
 }
+
+__device__ __forceinline__ float silu(const float& val) {
+  return val / (1.0f + __expf(-val));
+}
+
+template <class Type>
+inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
+  float2 x[CVT_FP4_ELTS_PER_THREAD / 2];
+  float2 y[CVT_FP4_ELTS_PER_THREAD / 2];
+#pragma unroll
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    if constexpr (std::is_same_v<Type, half>) {
+      x[i] = __half22float2(x_vec.elts[i]);
+      y[i] = __half22float2(y_vec.elts[i]);
+      x[i].x = silu(x[i].x) * y[i].x;
+      x[i].y = silu(x[i].y) * y[i].y;
+      x_vec.elts[i] = __float22half2_rn(x[i]);
+    } else {
+      x[i] = __bfloat1622float2(x_vec.elts[i]);
+      y[i] = __bfloat1622float2(y_vec.elts[i]);
+      x[i].x = silu(x[i].x) * y[i].x;
+      x[i].y = silu(x[i].y) * y[i].y;
+      x_vec.elts[i] = __float22bfloat162_rn(x[i]);
+    }
+  }
+}
+
 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
 __global__ void
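For reference, the silu_and_mul helper added above computes, per element, silu(x) * y with silu(v) = v / (1 + exp(-v)), doing the math in fp32 and converting back to half or bfloat16. A minimal PyTorch sketch of the same elementwise math (the function name is illustrative, not part of this commit):

import torch

def silu_and_mul_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Compute silu(x) * y in fp32, then cast back to the input dtype,
    # mirroring the per-element behavior of the CUDA helper above.
    x32, y32 = x.float(), y.float()
    return (x32 / (1.0 + torch.exp(-x32)) * y32).to(x.dtype)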
@@ -255,6 +282,7 @@ cvt_fp16_to_fp4(
     uint32_t* SFout,
     uint32_t* input_offset_by_experts,
     uint32_t* output_scale_offset_by_experts,
+    int32_t* mask,
     int n_experts,
     bool low_latency) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
@@ -265,6 +293,11 @@ cvt_fp16_to_fp4(
   // Input tensor row/col loops.
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+  // TODO(kaixih@nvidia): For now, we assume mask is used together with
+  // silu_and_mul. Maybe we want a more general behavior of mask later. In the
+  // silu case, the input last dim doubles.
+  bool use_mask = mask != nullptr;
+  int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;
   // Each global thread processes one element
   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
@@ -272,13 +305,6 @@ cvt_fp16_to_fp4(
     int rowIdx = globalIdx / colsPerRow;
     int colIdx = globalIdx % colsPerRow;
-    int64_t inOffset = rowIdx * colsPerRow + colIdx;
-    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-    // Get the output tensor offset.
-    // Same as inOffset because 8 elements are packed into one uint32_t.
-    int64_t outOffset = inOffset;
-    auto& out_pos = out[outOffset];
     // Find index within the experts using different strategies based on expert
     // count
     int rowIdx_in_expert = 0;
@@ -321,6 +347,23 @@ cvt_fp16_to_fp4(
       }
     }
+    // Early exit when using masks.
+    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
+      continue;
+    }
+
+    int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    if (use_mask) {
+      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
+      silu_and_mul(in_vec, in_vec_mul);
+    }
+
+    // Get the output tensor offset.
+    // Same as inOffset because 8 elements are packed into one uint32_t.
+    int64_t outOffset = rowIdx * colsPerRow + colIdx;
+    auto& out_pos = out[outOffset];
     // Get the global scaling factor, which will be applied to the SF.
     // Note SFScale is the same as next GEMM's alpha, which is
     // (448.f / (Alpha_A / 6.f)).
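In effect the mask gates whole rows per expert: a thread first resolves which expert its row belongs to, skips the row if rowIdx_in_expert >= mask[expert_idx], and otherwise reads the row from the doubled-width input, applies silu_and_mul across the two halves, and quantizes the result. A rough PyTorch sketch of that row-level semantics for a single expert's slice, stopping before the FP4 conversion (names are illustrative):

import torch

def masked_silu_and_mul_rows(expert_rows: torch.Tensor, valid_rows: int) -> torch.Tensor:
    # expert_rows: (rows, 2 * k) slice of one expert; only the first `valid_rows`
    # rows are processed, matching the kernel's early exit on the mask.
    k = expert_rows.shape[1] // 2
    x = expert_rows[:valid_rows, :k].float()   # first half of the doubled last dim
    y = expert_rows[:valid_rows, k:].float()   # second half
    return (x / (1.0 + torch.exp(-x)) * y).to(expert_rows.dtype)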
@@ -356,6 +399,7 @@ cvt_fp16_to_fp4(
     uint32_t* SFout,
     uint32_t* input_offset_by_experts,
     uint32_t* output_scale_offset_by_experts,
+    int32_t* mask,
     int n_experts) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
@@ -383,6 +427,8 @@ cvt_fp16_to_fp4(
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+  bool use_mask = mask != nullptr;
+  int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;
   // Each global thread processes one element
   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
@@ -390,11 +436,6 @@ cvt_fp16_to_fp4(
     int rowIdx = globalIdx / colsPerRow;
     int colIdx = globalIdx % colsPerRow;
-    int64_t inOffset = rowIdx * colsPerRow + colIdx;
-    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-    int64_t outOffset = inOffset;
-    auto& out_pos = out[outOffset];
     // Find expert using binary search for better performance with large m_topk
     int rowIdx_in_expert = 0;
     int expert_idx = 0;
@@ -419,6 +460,21 @@ cvt_fp16_to_fp4(
       }
     }
+    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
+      continue;
+    }
+    int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    if (use_mask) {
+      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
+      silu_and_mul(in_vec, in_vec_mul);
+    }
+    int64_t outOffset = rowIdx * colsPerRow + colIdx;
+    auto& out_pos = out[outOffset];
     float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
     int factor = CVT_FP4_SF_VEC_SIZE * 4;
@@ -442,6 +498,7 @@ void quant_impl(
     void* input_global_scale,
     void* input_offset_by_experts,
     void* output_scale_offset_by_experts,
+    void* mask,
     int m_topk,
     int k,
     int n_experts,
@@ -478,6 +535,7 @@ void quant_impl(
         reinterpret_cast<uint32_t*>(output_scale),
         reinterpret_cast<uint32_t*>(input_offset_by_experts),
         reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+        reinterpret_cast<int32_t*>(mask),
         n_experts);
   } else {
     cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
@@ -489,6 +547,7 @@ void quant_impl(
         reinterpret_cast<uint32_t*>(output_scale),
         reinterpret_cast<uint32_t*>(input_offset_by_experts),
         reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+        reinterpret_cast<int32_t*>(mask),
         n_experts);
     }
   } else {
@@ -502,6 +561,7 @@ void quant_impl(
         reinterpret_cast<uint32_t*>(output_scale),
         reinterpret_cast<uint32_t*>(input_offset_by_experts),
         reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+        reinterpret_cast<int32_t*>(mask),
         n_experts,
         /* bool low_latency */ true);
   } else {
@@ -514,6 +574,7 @@ void quant_impl(
         reinterpret_cast<uint32_t*>(output_scale),
         reinterpret_cast<uint32_t*>(input_offset_by_experts),
         reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+        reinterpret_cast<int32_t*>(mask),
         n_experts,
         /* bool low_latency */ true);
   }
@@ -590,6 +651,92 @@ void scaled_fp4_experts_quant_sm100a(
         input_global_scale.data_ptr(),
         input_offset_by_experts.data_ptr(),
         output_scale_offset_by_experts.data_ptr(),
+        nullptr,  // mask
+        m_topk,
+        k,
+        n_experts,
+        stream);
+  } else if (in_dtype == at::ScalarType::BFloat16) {
+    quant_impl<__nv_bfloat16>(
+        output.data_ptr(),
+        output_scale.data_ptr(),
+        input.data_ptr(),
+        input_global_scale.data_ptr(),
+        input_offset_by_experts.data_ptr(),
+        output_scale_offset_by_experts.data_ptr(),
+        nullptr,  // mask
+        m_topk,
+        k,
+        n_experts,
+        stream);
+  } else {
+    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
+  }
+}
+
+void silu_and_mul_scaled_fp4_experts_quant_sm100a(
+    torch::Tensor& output,
+    torch::Tensor& output_scale,
+    torch::Tensor const& input,
+    torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts,
+    torch::Tensor const& mask) {
+  CHECK_INPUT(output, "output must be a CUDA tensor");
+  CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
+  CHECK_INPUT(input, "input must be a CUDA tensor");
+  CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
+  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
+  CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
+  CHECK_INPUT(mask, "mask must be a CUDA tensor");
+
+  TORCH_CHECK(output.dim() == 2);
+  TORCH_CHECK(output_scale.dim() == 2);
+  TORCH_CHECK(input.dim() == 2);
+  TORCH_CHECK(input_global_scale.dim() == 1);
+  TORCH_CHECK(input_offset_by_experts.dim() == 1);
+  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
+  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
+  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
+  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
+  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
+  TORCH_CHECK(mask.scalar_type() == INT);
+  // output is uint8 (two nvfp4 values are packed into one uint8)
+  // output_scale is int32 (four fp8 values are packed into one int32)
+  TORCH_CHECK(output.scalar_type() == UINT8);
+  TORCH_CHECK(output_scale.scalar_type() == INT);
+
+  const int BLOCK_SIZE = 16;
+  auto m_topk = input.size(0);
+  auto k_by_2 = input.size(1);
+  TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
+  auto k = k_by_2 / 2;
+  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
+  auto n_experts = input_global_scale.size(0);
+  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
+  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
+  TORCH_CHECK(mask.size(0) == n_experts);
+  TORCH_CHECK(output.size(0) == m_topk);
+  TORCH_CHECK(output.size(1) == k / 2);
+  int scales_k = k / BLOCK_SIZE;
+  // 4 means the swizzle requirement by nvidia nvfp4.
+  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
+  // 4 means 4 fp8 values are packed into one int32
+  TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
+
+  auto in_dtype = input.dtype();
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
+  if (in_dtype == at::ScalarType::Half) {
+    quant_impl<half>(
+        output.data_ptr(),
+        output_scale.data_ptr(),
+        input.data_ptr(),
+        input_global_scale.data_ptr(),
+        input_offset_by_experts.data_ptr(),
+        output_scale_offset_by_experts.data_ptr(),
+        mask.data_ptr(),
         m_topk,
         k,
         n_experts,
@@ -602,6 +749,7 @@ void scaled_fp4_experts_quant_sm100a(
         input_global_scale.data_ptr(),
         input_offset_by_experts.data_ptr(),
         output_scale_offset_by_experts.data_ptr(),
+        mask.data_ptr(),
         m_topk,
         k,
         n_experts,
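One note on the blockscale sizing that the new host entry enforces: with BLOCK_SIZE = 16 there is one FP8 scale per 16 features, the scale columns are padded to a multiple of 4 for the nvfp4 swizzle, and 4 FP8 scales are packed per int32, which is what the `output_scale.size(1) * 4 == padded_k` check expresses. A quick sanity check of that arithmetic (plain Python, not part of the commit):

def expected_scale_width_int32(k: int, block_size: int = 16) -> int:
    # One FP8 block scale per `block_size` features, padded to a multiple of 4
    # columns (nvfp4 swizzle requirement), then 4 FP8 values packed per int32.
    scales_k = k // block_size
    padded_k = (scales_k + 3) // 4 * 4
    return padded_k // 4

# e.g. k = 2048 -> 128 block scales -> padded to 128 -> 32 int32 columns
assert expected_scale_width_int32(2048) == 32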
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu

@@ -27,6 +27,15 @@ void scaled_fp4_experts_quant_sm100a(
     torch::Tensor const& input_offset_by_experts,
     torch::Tensor const& output_scale_offset_by_experts);
+
+void silu_and_mul_scaled_fp4_experts_quant_sm100a(
+    torch::Tensor& output,
+    torch::Tensor& output_scale,
+    torch::Tensor const& input,
+    torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts,
+    torch::Tensor const& mask);
 #endif

 void scaled_fp4_quant(
@@ -50,3 +59,18 @@ void scaled_fp4_experts_quant(
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
 }
+
+void silu_and_mul_scaled_fp4_experts_quant(
+    torch::Tensor& output,
+    torch::Tensor& output_scale,
+    torch::Tensor const& input,
+    torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts,
+    torch::Tensor const& mask) {
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+  return silu_and_mul_scaled_fp4_experts_quant_sm100a(
+      output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
+}
sgl-kernel/include/sgl_kernel_ops.h

@@ -389,6 +389,14 @@ void scaled_fp4_experts_quant(
     torch::Tensor const& input_offset_by_experts,
     torch::Tensor const& output_scale_offset_by_experts);
+
+void silu_and_mul_scaled_fp4_experts_quant(
+    torch::Tensor& output,
+    torch::Tensor& output_scale,
+    torch::Tensor const& input,
+    torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts,
+    torch::Tensor const& mask);

 /*
  * From csrc/moe/cutlass_moe/w4a8
  */
sgl-kernel/python/sgl_kernel/__init__.py

@@ -52,12 +52,14 @@ from sgl_kernel.gemm import (
     qserve_w4a8_per_chn_gemm,
     qserve_w4a8_per_group_gemm,
     scaled_fp4_experts_quant,
+    scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
     sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
     shuffle_rows,
+    silu_and_mul_scaled_fp4_grouped_quant,
 )
 from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
 from sgl_kernel.kvcacheio import (
sgl-kernel/python/sgl_kernel/gemm.py

@@ -295,6 +295,142 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
     return output_tensor
+
+
+def scaled_fp4_grouped_quant(
+    input_tensor: torch.Tensor,
+    input_global_scale: torch.Tensor,
+):
+    """
+    Quantize the input tensor to FP4 and return the quantized tensor and scales, for
+    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
+    Args:
+        input: The input tensor to be quantized to FP4, with shape (l, m, k), where
+            l is the number of groups, m is the number of tokens per group, and k is
+            the number of features.
+        input_global_scale: A scalar scaling factor for the entire tensor, with
+            shape (l,).
+    Outputs:
+        output: The quantized tensor in FP4, with shape (m, k // 2, l) but a physical
+            layout of (l, m, k // 2). `// 2` is because two fp4 values are packed into
+            a uint8.
+        output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
+            but a physical layout of (l, rm, rk, 32, 4, 4).
+    Note:
+        For the shape of output_scales, `32 * 4 * rm` is m padded to the nearest multiple
+        of 128, and `4 * rk` is `k // 16` padded to the nearest multiple of 4. These layout
+        constants are required by the NVIDIA Blackwell MMA operations.
+    """
+    device = input_tensor.device
+    l, m, k = input_tensor.shape
+    sf_vec_size = 16
+    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
+    scale_k = k // sf_vec_size
+    padded_k = (scale_k + (4 - 1)) // 4 * 4
+    padded_k_int32 = padded_k // 4
+    padded_m = (m + (128 - 1)) // 128 * 128
+    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output_scales = torch.empty(
+        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
+    )
+    input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
+    output_offsets = torch.arange(
+        0,
+        (l + 1) * padded_m,
+        step=padded_m,
+        dtype=torch.int,
+        device=device,
+    )
+    torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
+        output.view(l * m, k // 2),
+        output_scales.view(l * padded_m, padded_k_int32),
+        input_tensor.view(l * m, k),
+        input_global_scale,
+        input_offsets,
+        output_offsets,
+    )
+    # The physical layout of the output is (l, m, k // 2), but we want to return a
+    # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
+    output = output.permute(1, 2, 0)
+    # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
+    # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logical
+    # layout is (32, 4, rm, 4, rk, l).
+    output_scales = output_scales.view(torch.float8_e4m3fn).view(
+        l, padded_m // 128, padded_k // 4, 32, 4, 4
+    )
+    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
+    return output, output_scales
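As a usage sketch (not part of this commit), the per-group global scale is typically derived from the group's absolute maximum, as the new test below does; the constant names mirror the test file, and their values (FP8-E4M3 max 448, FP4-E2M1 max 6) are assumptions consistent with the kernel's `(448.f / (Alpha_A / 6.f))` comment:

import torch
from sgl_kernel import scaled_fp4_grouped_quant

FLOAT8_E4M3_MAX = 448.0  # assumed value of the test-file constant
FLOAT4_E2M1_MAX = 6.0    # assumed value of the test-file constant

l, m, k = 2, 512, 2048
x = torch.randn((l, m, k), dtype=torch.bfloat16, device="cuda")
# One global scale per group, from the per-group absolute maximum.
x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / x.abs().amax(dim=(1, 2)).to(torch.float32)
# output: packed FP4 with logical shape (m, k // 2, l); output_scales: swizzled FP8 blockscales.
output, output_scales = scaled_fp4_grouped_quant(x, x_sf_global)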
+
+
+def silu_and_mul_scaled_fp4_grouped_quant(
+    input_tensor: torch.Tensor,
+    input_global_scale: torch.Tensor,
+    mask: torch.Tensor,
+):
+    """
+    Quantize the input tensor to FP4 and return the quantized tensor and scales, for
+    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
+    Args:
+        input: The input tensor to be quantized to FP4, with shape (l, m, k * 2), where
+            l is the number of groups, m is the number of tokens per group, and k is
+            the number of features.
+        input_global_scale: A scalar scaling factor for the entire tensor, with
+            shape (l,).
+        mask: The mask tensor, with shape (l,).
+    Outputs:
+        output: The quantized tensor in FP4, with shape (m, k // 2, l) but a physical
+            layout of (l, m, k // 2). `// 2` is because two fp4 values are packed into
+            a uint8.
+        output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
+            but a physical layout of (l, rm, rk, 32, 4, 4).
+    Note:
+        For the shape of output_scales, `32 * 4 * rm` is m padded to the nearest multiple
+        of 128, and `4 * rk` is `k // 16` padded to the nearest multiple of 4. These layout
+        constants are required by the NVIDIA Blackwell MMA operations.
+    """
+    device = input_tensor.device
+    l, m, k_by_2 = input_tensor.shape
+    k = k_by_2 // 2
+    sf_vec_size = 16
+    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
+    scale_k = k // sf_vec_size
+    padded_k = (scale_k + (4 - 1)) // 4 * 4
+    padded_k_int32 = padded_k // 4
+    padded_m = (m + (128 - 1)) // 128 * 128
+    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output_scales = torch.empty(
+        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
+    )
+    input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
+    output_offsets = torch.arange(
+        0,
+        (l + 1) * padded_m,
+        step=padded_m,
+        dtype=torch.int,
+        device=device,
+    )
+    torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
+        output.view(l * m, k // 2),
+        output_scales.view(l * padded_m, padded_k_int32),
+        input_tensor.view(l * m, k_by_2),
+        input_global_scale,
+        input_offsets,
+        output_offsets,
+        mask,
+    )
+    # The physical layout of the output is (l, m, k // 2), but we want to return a
+    # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
+    output = output.permute(1, 2, 0)
+    # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
+    # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logical
+    # layout is (32, 4, rm, 4, rk, l).
+    output_scales = output_scales.view(torch.float8_e4m3fn).view(
+        l, padded_m // 128, padded_k // 4, 32, 4, 4
+    )
+    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
+    return output, output_scales
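The masked variant targets the MoE case where each group only has mask[i] valid rows; the kernel skips the rest, so only the first mask[i] rows of each group are meaningful in the output (the new test below checks exactly that slice). A hedged call sketch, reusing the assumed constants from the previous example:

import torch
from sgl_kernel import silu_and_mul, silu_and_mul_scaled_fp4_grouped_quant

FLOAT8_E4M3_MAX, FLOAT4_E2M1_MAX = 448.0, 6.0  # assumed values, as above

l, m, k = 32, 100, 2048
x = torch.randn((l, m, k * 2), dtype=torch.bfloat16, device="cuda")  # the two halves consumed by silu_and_mul
mask = torch.randint(1, 8, (l,), dtype=torch.int32, device="cuda")   # valid rows per group
# Derive the per-group global scales from the silu_and_mul reference output, as the test does.
ref_y = silu_and_mul(x)
y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(x, y_sf_global, mask)
# In the logical (m, k // 2, l) layout, only output[:mask[i], :, i] is defined for group i.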
 def scaled_fp4_experts_quant(
     input_tensor: torch.Tensor,
     input_global_scale: torch.Tensor,
sgl-kernel/tests/test_fp4_quantize.py

 import pytest
 import torch
-from sgl_kernel import scaled_fp4_quant
+from sgl_kernel import (
+    scaled_fp4_grouped_quant,
+    scaled_fp4_quant,
+    silu_and_mul,
+    silu_and_mul_scaled_fp4_grouped_quant,
+)

 skip_condition = torch.cuda.get_device_capability() < (10, 0)
@@ -166,5 +171,83 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
     torch.testing.assert_close(scale_ans, scale_ref)
+
+
+@pytest.mark.skipif(skip_condition, reason="Nvfp4 Requires compute capability of 10 or above.")
+def test_quantize_to_fp4_grouped():
+    torch.manual_seed(42)
+    torch.set_default_device("cuda:0")
+
+    l, m, k = 2, 512, 2048
+    x = torch.randn((l, m, k), dtype=torch.bfloat16)
+    tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
+    x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+
+    output, output_scales = scaled_fp4_grouped_quant(
+        x,
+        x_sf_global,
+    )
+    # output in logical (m, k, l), but its physical layout is (l, m, k).
+    # So permute first to (l, m, k).
+    output = output.permute(2, 0, 1)
+    # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
+    # So permute first to (l, rm, rk, 32, 4, 4).
+    padded_m = ((m + 128 - 1) // 128) * 128
+    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
+
+    for i in range(l):
+        a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
+        torch.testing.assert_close(a_fp4, output[i])
+        torch.testing.assert_close(
+            a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float)
+        )
+
+
+@pytest.mark.skipif(skip_condition, reason="Nvfp4 Requires compute capability of 10 or above.")
+@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)])
+def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
+    torch.manual_seed(42)
+    torch.set_default_device("cuda:0")
+
+    l, m, k = shape
+    x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
+    max_m = 8
+    assert max_m <= m
+    mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
+
+    ref_y = silu_and_mul(x)
+    tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
+    y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    ref_output, ref_output_scales = scaled_fp4_grouped_quant(
+        ref_y,
+        y_sf_global,
+    )
+    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
+        x,
+        y_sf_global,
+        mask,
+    )
+    # output in logical (m, k, l), but its physical layout is (l, m, k).
+    # So permute first to (l, m, k).
+    output = output.permute(2, 0, 1)
+    ref_output = ref_output.permute(2, 0, 1)
+    # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
+    # So permute first to (l, rm, rk, 32, 4, 4).
+    padded_m = ((m + 128 - 1) // 128) * 128
+    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
+    ref_output_scales = ref_output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
+
+    for i in range(l):
+        torch.testing.assert_close(ref_output[i, : mask[i]], output[i, : mask[i]])
+        # We need to recover the swizzled scales to linear layout before applying mask slice.
+        scale_ref = recover_swizzled_scales(ref_output_scales[i], m, k)
+        scale_ans = recover_swizzled_scales(output_scales[i], m, k)
+        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])