Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e6b1153
Unverified
Commit
1e6b1153
authored
Dec 12, 2025
by
Wentao Ye
Committed by
GitHub
Dec 12, 2025
Browse files
[Refactor] Reduce duplicate code in `per_token_group_quant` cuda kernels (#30496)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
13618626
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
98 deletions
+83
-98
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
+83
-98
No files found.
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
View file @
1e6b1153
...
...
@@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) {
return
val
;
}
// Stages one quantization group into shared memory and derives its scale.
//
// Every element read from `group_input` is written unchanged to `smem_group`
// while the running maximum of |x| (seeded with `eps`, so the result is never
// below `eps`) is tracked per thread. The per-thread maxima are then combined
// via GroupReduceMax (NOTE(review): presumably a reduction across the lanes
// that cooperate on this group -- confirm against its definition).
//
// Template parameters:
//   T            element type of the input / staging buffer.
//   SCALE_UE8M0  when true, the scale is rounded up to a power of two.
//
// Parameters:
//   group_input        source elements of this group.
//   smem_group         destination staging buffer (shared memory per the
//                      call-site comments).
//   group_size         number of elements in the group.
//   lane_id            this thread's index, forwarded to the vectorized helper.
//   threads_per_group  stride forwarded to the vectorized helper.
//   eps                lower bound used to seed the abs-max.
//   max_8bit           largest representable magnitude of the 8-bit target.
//
// Returns absmax / max_8bit, or 2^ceil(log2(max(|absmax / max_8bit|, 1e-10)))
// when SCALE_UE8M0 is set.
template <typename T, bool SCALE_UE8M0>
__device__ __forceinline__ float ComputeGroupScale(
    const T* __restrict__ group_input, T* __restrict__ smem_group,
    const int group_size, const int lane_id, const int threads_per_group,
    const float eps, const float max_8bit) {
  // Number of T elements that fit in a 16-byte access.
  constexpr int kVecElems = 16 / sizeof(T);

  float running_max = eps;

  // Per-element handler: fold |src| into the running max, then stage src.
  auto track_and_stage = [&] __device__(T& dst, const T& src) {
    const float magnitude = fabsf(static_cast<float>(src));
    running_max = fmaxf(running_max, magnitude);
    dst = src;
  };

  vllm::vectorize_with_alignment<kVecElems>(group_input,        // in
                                            smem_group,         // out (shared)
                                            group_size,         // elements
                                            lane_id,            // thread id
                                            threads_per_group,  // stride
                                            track_and_stage);   // handler

  float scale = GroupReduceMax(running_max) / max_8bit;

  if constexpr (SCALE_UE8M0) {
    // Round up to the next power of two; the 1e-10 floor keeps log2f finite.
    const float floored = fmaxf(fabsf(scale), 1e-10f);
    scale = exp2f(ceilf(log2f(floored)));
  }

  return scale;
}
// Quantizes one staged group into the 8-bit output tensor.
//
// Each element of `smem_group` is converted to float, divided by the scale
// `y_s`, clamped to [min_8bit, max_8bit] and converted to DST_DTYPE before
// being stored to `group_output`.
//
// Template parameters:
//   T          element type of the staged input.
//   DST_DTYPE  8-bit destination element type.
//
// Parameters:
//   smem_group         staged input elements (shared memory per the call-site
//                      comments).
//   group_output       destination for the quantized values.
//   group_size         number of elements in the group.
//   lane_id            this thread's index, forwarded to the vectorized helper.
//   threads_per_group  stride forwarded to the vectorized helper.
//   y_s                scale the inputs are divided by.
//   min_8bit/max_8bit  clamp bounds applied after scaling.
template <typename T, typename DST_DTYPE>
__device__ __forceinline__ void QuantizeGroup(
    const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
    const int group_size, const int lane_id, const int threads_per_group,
    const float y_s, const float min_8bit, const float max_8bit) {
  // Number of T elements that fit in a 16-byte access.
  constexpr int kVecElems = 16 / sizeof(T);

  // Per-element handler: scale, clamp from below, clamp from above, convert.
  auto scale_clamp_store = [&] __device__(DST_DTYPE& dst, const T& src) {
    const float scaled = static_cast<float>(src) / y_s;
    const float floor_clamped = fmaxf(scaled, min_8bit);
    const float clamped = fminf(floor_clamped, max_8bit);
    dst = DST_DTYPE(clamped);
  };

  vllm::vectorize_with_alignment<kVecElems>(
      smem_group,          // in (shared)
      group_output,        // out (global quant tensor)
      group_size,          // elements
      lane_id,             // tid
      threads_per_group,   // stride
      scale_clamp_store);  // scalar handler
}
template
<
typename
T
,
typename
DST_DTYPE
,
bool
IS_COLUMN_MAJOR
=
false
,
bool
SCALE_UE8M0
=
false
,
typename
scale_packed_t
=
float
>
__global__
void
per_token_group_quant_8bit_kernel
(
...
...
@@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel(
const
int64_t
global_group_id
=
block_group_id
+
local_group_id
;
const
int64_t
block_group_offset
=
global_group_id
*
group_size
;
float
local_absmax
=
eps
;
using
scale_element_t
=
float
;
static_assert
(
sizeof
(
scale_packed_t
)
%
sizeof
(
scale_element_t
)
==
0
);
...
...
@@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel(
T
*
smem
=
reinterpret_cast
<
T
*>
(
smem_raw
);
T
*
smem_group
=
smem
+
local_group_id
*
group_size
;
constexpr
int
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
vllm
::
vec_n_t
<
T
,
vec_size
>
;
// copy global -> shared & compute absmax
auto
scalar_op_cache
=
[
&
]
__device__
(
T
&
dst
,
const
T
&
src
)
{
float
abs_v
=
fabsf
(
static_cast
<
float
>
(
src
));
local_absmax
=
fmaxf
(
local_absmax
,
abs_v
);
dst
=
src
;
};
vllm
::
vectorize_with_alignment
<
vec_size
>
(
group_input
,
// in
smem_group
,
// out (shared)
group_size
,
// elements per group
lane_id
,
// thread id
threads_per_group
,
// stride in group
scalar_op_cache
);
// scalar handler
local_absmax
=
GroupReduceMax
(
local_absmax
);
float
y_s
=
local_absmax
/
max_8bit
;
if
constexpr
(
SCALE_UE8M0
)
{
y_s
=
exp2f
(
ceilf
(
log2f
(
fmaxf
(
fabsf
(
y_s
),
1e-10
f
))));
}
const
float
y_s
=
ComputeGroupScale
<
T
,
SCALE_UE8M0
>
(
group_input
,
smem_group
,
group_size
,
lane_id
,
threads_per_group
,
eps
,
max_8bit
);
scale_element_t
y_s_quant
=
y_s
;
...
...
@@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel(
__syncthreads
();
// quantize shared -> global 8-bit
auto
scalar_op_quant
=
[
&
]
__device__
(
DST_DTYPE
&
dst
,
const
T
&
src
)
{
float
q
=
fminf
(
fmaxf
(
static_cast
<
float
>
(
src
)
/
y_s
,
min_8bit
),
max_8bit
);
dst
=
DST_DTYPE
(
q
);
};
QuantizeGroup
<
T
,
DST_DTYPE
>
(
smem_group
,
group_output
,
group_size
,
lane_id
,
threads_per_group
,
y_s
,
min_8bit
,
max_8bit
);
}
vllm
::
vectorize_with_alignment
<
vec_size
>
(
smem_group
,
// in (shared)
group_output
,
// out (global quant tensor)
group_size
,
// elements
lane_id
,
// tid
threads_per_group
,
// stride
scalar_op_quant
);
// scalar handler
// Picks how many groups a single block handles: the largest of 16, 8, 4, 2
// that divides `num_groups` evenly, or 1 when none of them does.
inline int GetGroupsPerBlock(int64_t num_groups) {
  // Walk the candidate powers of two from largest to smallest.
  for (int candidate = 16; candidate >= 2; candidate /= 2) {
    if (num_groups % candidate == 0) {
      return candidate;
    }
  }
  return 1;
}
void
per_token_group_quant_8bit
(
const
torch
::
Tensor
&
input
,
...
...
@@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
constexpr
int
THREADS_PER_GROUP
=
16
;
int
groups_per_block
=
1
;
if
(
num_groups
%
16
==
0
)
{
groups_per_block
=
16
;
}
else
if
(
num_groups
%
8
==
0
)
{
groups_per_block
=
8
;
}
else
if
(
num_groups
%
4
==
0
)
{
groups_per_block
=
4
;
}
else
if
(
num_groups
%
2
==
0
)
{
groups_per_block
=
2
;
}
const
int
groups_per_block
=
GetGroupsPerBlock
(
num_groups
);
auto
dst_type
=
output_q
.
scalar_type
();
const
int
num_blocks
=
num_groups
/
groups_per_block
;
...
...
@@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
const
int64_t
block_group_offset
=
global_group_id
*
group_size
;
float
local_absmax
=
eps
;
const
T
*
group_input
=
input
+
block_group_offset
;
DST_DTYPE
*
group_output
=
static_cast
<
DST_DTYPE
*>
(
output_q
)
+
block_group_offset
;
...
...
@@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
extern
__shared__
__align__
(
16
)
char
smem_raw
[];
T
*
smem
=
reinterpret_cast
<
T
*>
(
smem_raw
);
T
*
smem_group
=
smem
+
local_group_id
*
group_size
;
constexpr
int
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
vllm
::
vec_n_t
<
T
,
vec_size
>
;
// copy global -> shared & compute absmax
auto
scalar_op_cache
=
[
&
]
__device__
(
T
&
dst
,
const
T
&
src
)
{
float
abs_v
=
fabsf
(
static_cast
<
float
>
(
src
));
local_absmax
=
fmaxf
(
local_absmax
,
abs_v
);
dst
=
src
;
};
vllm
::
vectorize_with_alignment
<
vec_size
>
(
group_input
,
// in
smem_group
,
// out (shared)
group_size
,
// elements per group
lane_id
,
// thread id
threads_per_group
,
// stride in group
scalar_op_cache
);
// scalar handler
local_absmax
=
GroupReduceMax
(
local_absmax
);
float
y_s
=
local_absmax
/
max_8bit
;
y_s
=
exp2f
(
ceilf
(
log2f
(
fmaxf
(
fabsf
(
y_s
),
1e-10
f
))));
const
float
y_s
=
ComputeGroupScale
<
T
,
true
>
(
group_input
,
smem_group
,
group_size
,
lane_id
,
threads_per_group
,
eps
,
max_8bit
);
// pack 4 scales into a uint32
if
(
lane_id
==
0
)
{
...
...
@@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
__syncthreads
();
// quantize shared -> global 8-bit
auto
scalar_op_quant
=
[
&
]
__device__
(
DST_DTYPE
&
dst
,
const
T
&
src
)
{
float
q
=
fminf
(
fmaxf
(
static_cast
<
float
>
(
src
)
/
y_s
,
min_8bit
),
max_8bit
);
dst
=
DST_DTYPE
(
q
);
};
vllm
::
vectorize_with_alignment
<
vec_size
>
(
smem_group
,
// in (shared)
group_output
,
// out (global quant tensor)
group_size
,
// elements
lane_id
,
// tid
threads_per_group
,
// stride
scalar_op_quant
);
// scalar handler
QuantizeGroup
<
T
,
DST_DTYPE
>
(
smem_group
,
group_output
,
group_size
,
lane_id
,
threads_per_group
,
y_s
,
min_8bit
,
max_8bit
);
}
void
per_token_group_quant_8bit_packed
(
const
torch
::
Tensor
&
input
,
...
...
@@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
constexpr
int
THREADS_PER_GROUP
=
16
;
int
groups_per_block
=
1
;
if
(
num_groups
%
16
==
0
)
{
groups_per_block
=
16
;
}
else
if
(
num_groups
%
8
==
0
)
{
groups_per_block
=
8
;
}
else
if
(
num_groups
%
4
==
0
)
{
groups_per_block
=
4
;
}
else
if
(
num_groups
%
2
==
0
)
{
groups_per_block
=
2
;
}
const
int
groups_per_block
=
GetGroupsPerBlock
(
num_groups
);
auto
dst_type
=
output_q
.
scalar_type
();
const
int
num_blocks
=
num_groups
/
groups_per_block
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment