Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
294fc1e2
Unverified
Commit
294fc1e2
authored
Jun 14, 2025
by
jiahanc
Committed by
GitHub
Jun 14, 2025
Browse files
[Hardware][NVIDIA][kernel] Fp4 MOE quant kernel optimization (#19500)
parent
2db9044a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
226 additions
and
48 deletions
+226
-48
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+226
-48
No files found.
csrc/quantization/fp4/nvfp4_experts_quant.cu
View file @
294fc1e2
...
...
@@ -231,7 +231,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
bool
UE8M0_SF
=
false
,
bool
SMALL_NUM_EXPERTS
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
...
...
@@ -240,7 +240,7 @@ cvt_fp16_to_fp4(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
,
bool
low_latency
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
...
...
@@ -248,28 +248,68 @@ cvt_fp16_to_fp4(
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Input tensor row/col loops.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
colIdx
=
threadIdx
.
x
;
colIdx
<
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
colIdx
+=
blockDim
.
x
)
{
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
colsPerRow
=
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
// Each global thread processes one element
for
(
int
globalIdx
=
tid
;
globalIdx
<
numRows
*
colsPerRow
;
globalIdx
+=
gridDim
.
x
*
blockDim
.
x
)
{
// Calculate which row and column this global thread should process
int
rowIdx
=
globalIdx
/
colsPerRow
;
int
colIdx
=
globalIdx
%
colsPerRow
;
int64_t
inOffset
=
rowIdx
*
colsPerRow
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find index within the experts.
// Find index within the experts using different strategies based on expert
// count
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
if
constexpr
(
SMALL_NUM_EXPERTS
)
{
for
(
int
i
=
0
;
i
<
n_experts
;
i
++
)
{
if
(
rowIdx
>=
input_offset_by_experts
[
i
]
&&
rowIdx
<
input_offset_by_experts
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
input_offset_by_experts
[
i
];
uint32_t
current_offset
=
__ldca
(
&
input_offset_by_experts
[
i
]);
uint32_t
next_offset
=
__ldca
(
&
input_offset_by_experts
[
i
+
1
]);
if
(
rowIdx
>=
current_offset
&&
rowIdx
<
next_offset
)
{
rowIdx_in_expert
=
rowIdx
-
current_offset
;
expert_idx
=
i
;
break
;
}
}
}
else
{
// Load input offsets into registers first, then do the computation.
// Local array size set to 17 because of register limit.
uint32_t
local_offsets
[
17
];
for
(
int
chunk_start
=
0
;
chunk_start
<
n_experts
;
chunk_start
+=
16
)
{
*
reinterpret_cast
<
int4
*>
(
local_offsets
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
]));
*
reinterpret_cast
<
int4
*>
(
local_offsets
+
4
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
+
4
]));
*
reinterpret_cast
<
int4
*>
(
local_offsets
+
8
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
+
8
]));
*
reinterpret_cast
<
int4
*>
(
local_offsets
+
12
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
+
12
]));
local_offsets
[
16
]
=
__ldca
(
&
input_offset_by_experts
[
chunk_start
+
16
]);
// Check against the 16 loaded offsets
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
if
(
rowIdx
>=
local_offsets
[
i
]
&&
rowIdx
<
local_offsets
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
local_offsets
[
i
];
expert_idx
=
chunk_start
+
i
;
break
;
}
}
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
...
...
@@ -288,10 +328,103 @@ cvt_fp16_to_fp4(
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
#endif
}
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template
<
class
Type
,
bool
UE8M0_SF
=
false
,
bool
SMALL_NUM_EXPERTS
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
1024
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
extern
__shared__
uint32_t
shared_input_offsets
[];
// Load input offsets into shared memory.
// If n_experts is larger than 4, use vectorized int4 to save instructions.
// If n_experts is smaller than 4, read directly.
if
constexpr
(
SMALL_NUM_EXPERTS
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
n_experts
+
1
;
i
+=
blockDim
.
x
)
{
shared_input_offsets
[
i
]
=
input_offset_by_experts
[
i
];
}
}
else
{
for
(
int
i
=
threadIdx
.
x
*
4
;
i
<
n_experts
;
i
+=
blockDim
.
x
*
4
)
{
*
reinterpret_cast
<
int4
*>
(
&
shared_input_offsets
[
i
])
=
*
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
i
]);
}
if
(
threadIdx
.
x
==
0
)
{
shared_input_offsets
[
n_experts
]
=
input_offset_by_experts
[
n_experts
];
}
}
__syncthreads
();
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
colsPerRow
=
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
// Each global thread processes one element
for
(
int
globalIdx
=
tid
;
globalIdx
<
numRows
*
colsPerRow
;
globalIdx
+=
gridDim
.
x
*
blockDim
.
x
)
{
// Calculate which row and column this global thread should process
int
rowIdx
=
globalIdx
/
colsPerRow
;
int
colIdx
=
globalIdx
%
colsPerRow
;
int64_t
inOffset
=
rowIdx
*
colsPerRow
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find expert using binary search for better performance with large m_topk
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
// Binary search through experts using shared memory
int
left
=
0
,
right
=
n_experts
-
1
;
while
(
left
<=
right
)
{
int
mid
=
(
left
+
right
)
/
2
;
// Get offsets: shared_input_offsets[i] corresponds to
// input_offset_by_experts[i]
uint32_t
mid_offset
=
shared_input_offsets
[
mid
];
uint32_t
next_offset
=
shared_input_offsets
[
mid
+
1
];
if
(
rowIdx
>=
mid_offset
&&
rowIdx
<
next_offset
)
{
rowIdx_in_expert
=
rowIdx
-
mid_offset
;
expert_idx
=
mid
;
break
;
}
else
if
(
rowIdx
<
mid_offset
)
{
right
=
mid
-
1
;
}
else
{
left
=
mid
+
1
;
}
}
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
numCols_SFout
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
#endif
}
...
...
@@ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input,
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
k
/
ELTS_PER_THREAD
),
512
));
int
const
workSizePerRow
=
k
/
ELTS_PER_THREAD
;
int
const
totalWorkSize
=
m_topk
*
workSizePerRow
;
dim3
block
(
std
::
min
(
workSizePerRow
,
512
));
// Get number of blocks per SM (assume we can fully utilize the SM).
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m_topk
),
multiProcessorCount
*
numBlocksPerSM
));
dim3
grid
(
std
::
min
(
static_cast
<
int
>
((
totalWorkSize
+
block
.
x
-
1
)
/
block
.
x
),
multiProcessorCount
*
numBlocksPerSM
));
while
(
grid
.
x
<=
multiProcessorCount
&&
block
.
x
>
64
)
{
grid
.
x
*=
2
;
block
.
x
=
(
block
.
x
+
1
)
/
2
;
}
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
int
const
blockRepeat
=
(
totalWorkSize
+
block
.
x
*
grid
.
x
-
1
)
/
(
block
.
x
*
grid
.
x
);
if
(
blockRepeat
>
1
)
{
size_t
shared_mem_size
=
(
n_experts
+
1
)
*
sizeof
(
uint32_t
);
if
(
n_experts
>=
4
)
{
cvt_fp16_to_fp4
<
T
,
false
,
false
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
else
{
cvt_fp16_to_fp4
<
T
,
false
,
true
><<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
}
else
{
if
(
n_experts
>=
16
)
{
cvt_fp16_to_fp4
<
T
,
false
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
,
/* bool low_latency */
true
);
}
else
{
cvt_fp16_to_fp4
<
T
,
false
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
,
/* bool low_latency */
true
);
}
}
}
/*Quantization entry for fp4 experts quantization*/
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment