Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
783921d8
Unverified
Commit
783921d8
authored
Jul 04, 2025
by
Wentao Ye
Committed by
GitHub
Jul 04, 2025
Browse files
[Perf] Optimize Vectorization Utils for Int 8 Quantization Kernels (#20331)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
4a98edff
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
7 deletions
+106
-7
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+9
-7
csrc/quantization/vectorization_utils.cuh
csrc/quantization/vectorization_utils.cuh
+97
-0
No files found.
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
783921d8
...
@@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
...
@@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
// calculate for absmax
// calculate for absmax
float
thread_max
=
0.
f
;
float
thread_max
=
0.
f
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
stride
)
{
vectorize_read_with_alignment
<
16
>
(
const
auto
v
=
fabsf
(
static_cast
<
float
>
(
row_in
[
i
]));
row_in
,
hidden_size
,
tid
,
stride
,
[
&
]
__device__
(
const
scalar_t
&
src
)
{
thread_max
=
fmaxf
(
thread_max
,
v
);
const
float
v
=
fabsf
(
static_cast
<
float
>
(
src
));
}
thread_max
=
fmaxf
(
thread_max
,
v
);
});
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
thread_max
,
cub
::
Max
{},
blockDim
.
x
);
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
thread_max
,
cub
::
Max
{},
blockDim
.
x
);
...
@@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
...
@@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// 1. calculate min & max
// 1. calculate min & max
MinMax
thread_mm
;
MinMax
thread_mm
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
stride
)
{
vectorize_read_with_alignment
<
16
>
(
row_in
,
hidden_size
,
tid
,
stride
,
thread_mm
+=
static_cast
<
float
>
(
row_in
[
i
]);
[
&
]
__device__
(
const
scalar_t
&
src
)
{
}
thread_mm
+=
static_cast
<
float
>
(
src
);
});
using
BlockReduce
=
cub
::
BlockReduce
<
MinMax
,
256
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
MinMax
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
...
...
csrc/quantization/vectorization_utils.cuh
View file @
783921d8
...
@@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
...
@@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
constexpr
int
WIDTH
=
VEC_SIZE
*
sizeof
(
InT
);
// eg: 64 B
constexpr
int
WIDTH
=
VEC_SIZE
*
sizeof
(
InT
);
// eg: 64 B
uintptr_t
addr
=
reinterpret_cast
<
uintptr_t
>
(
in
);
uintptr_t
addr
=
reinterpret_cast
<
uintptr_t
>
(
in
);
// fast path when the whole region is already aligned
// Note: currently the output is guaranteed to be same as the input, so we
// don't check it here, comments here just for future reference.
bool
can_vec
=
((
addr
&
(
WIDTH
-
1
))
==
0
)
&&
((
len
&
(
VEC_SIZE
-
1
))
==
0
);
if
(
can_vec
)
{
int
num_vec
=
len
/
VEC_SIZE
;
using
vin_t
=
vec_n_t
<
InT
,
VEC_SIZE
>
;
using
vout_t
=
vec_n_t
<
OutT
,
VEC_SIZE
>
;
auto
*
v_in
=
reinterpret_cast
<
const
vin_t
*>
(
in
);
auto
*
v_out
=
reinterpret_cast
<
vout_t
*>
(
out
);
for
(
int
i
=
tid
;
i
<
num_vec
;
i
+=
stride
)
{
vout_t
tmp
;
vec_op
(
tmp
,
v_in
[
i
]);
v_out
[
i
]
=
tmp
;
}
return
;
}
int
misalignment_offset
=
addr
&
(
WIDTH
-
1
);
// addr % 64
int
misalignment_offset
=
addr
&
(
WIDTH
-
1
);
// addr % 64
int
alignment_bytes
=
WIDTH
-
misalignment_offset
;
// 64 - (addr % 64)
int
alignment_bytes
=
WIDTH
-
misalignment_offset
;
// 64 - (addr % 64)
int
prefix_elems
=
alignment_bytes
&
(
WIDTH
-
1
);
// handle 64
int
prefix_elems
=
alignment_bytes
&
(
WIDTH
-
1
);
// handle 64
...
@@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
...
@@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
std
::
forward
<
ScaOp
>
(
scalar_op
));
std
::
forward
<
ScaOp
>
(
scalar_op
));
}
}
// Adapter that turns a per-element functor into a per-vector functor.
//
// Holds a copy of the scalar functor and, when invoked with a
// vec_n_t<InT, VEC_SIZE>, calls that functor once for each of the VEC_SIZE
// entries of `src.val`, in ascending index order.
//
// Template parameters:
//   VEC_SIZE - number of elements stored in one vec_n_t.
//   InT      - element type of the vector.
//   ScaOp    - functor type invoked as scalar_op(element).
//
// operator() is const, so ScaOp must be callable on a const object.
template <int VEC_SIZE, typename InT, typename ScaOp>
struct DefaultReadVecOp {
  ScaOp scalar_op;

  __device__ __forceinline__ void operator()(
      const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
    for (int lane = 0; lane < VEC_SIZE; ++lane) {
      scalar_op(src.val[lane]);
    }
  }
};
// Read-only traversal of `in[0, len)` that uses vector-wide accesses wherever
// the address is aligned to one full vector, and scalar accesses elsewhere.
//
// Parameters:
//   in        - start of the input range.
//   len       - number of InT elements in the range.
//   tid       - first index this caller handles in every loop below.
//   stride    - step added to the index on each loop iteration.
//   vec_op    - invoked as vec_op(const vec_n_t<InT, VEC_SIZE>&) for each
//               vector visited.
//   scalar_op - invoked as scalar_op(const InT&) for each element that is
//               visited outside the vectorized region.
//
// Nothing is written and nothing is returned; all effects come from the two
// functors.
//
// NOTE(review): the mask arithmetic with (WIDTH - 1) presumes WIDTH is a power
// of two. The static_assert only checks VEC_SIZE, so this additionally relies
// on sizeof(InT) being a power of two - confirm for all instantiations.
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
                                                     int tid, int stride,
                                                     VecOp&& vec_op,
                                                     ScaOp&& scalar_op) {
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");

  using vin_t = vec_n_t<InT, VEC_SIZE>;
  constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // bytes covered by one vector

  const uintptr_t base_addr = reinterpret_cast<uintptr_t>(in);
  // Distance in bytes from the previous WIDTH-aligned address.
  const int misaligned_bytes = base_addr & (WIDTH - 1);

  // Fast path: the start is aligned and the length is a whole number of
  // vectors, so the entire range can be walked with vector accesses only.
  if (misaligned_bytes == 0 && (len & (VEC_SIZE - 1)) == 0) {
    const vin_t* packed = reinterpret_cast<const vin_t*>(in);
    const int packed_count = len / VEC_SIZE;
    for (int v = tid; v < packed_count; v += stride) {
      vec_op(packed[v]);
    }
    return;
  }

  // Number of leading elements before the next WIDTH-aligned address. The
  // mask maps the "already aligned" case (WIDTH bytes) to 0, and the result
  // is capped at len so a short range never reads past its end.
  int head_elems = (WIDTH - misaligned_bytes) & (WIDTH - 1);
  head_elems /= sizeof(InT);
  head_elems = min(head_elems, len);

  // 1. scalar access for the (possibly unaligned) head.
  for (int e = tid; e < head_elems; e += stride) {
    scalar_op(in[e]);
  }

  // Everything after the head starts at an aligned address.
  const InT* body = in + head_elems;
  const int remaining = len - head_elems;
  const int packed_count = remaining / VEC_SIZE;
  const vin_t* packed = reinterpret_cast<const vin_t*>(body);

  // 2. vector access for the aligned middle region.
  for (int v = tid; v < packed_count; v += stride) {
    vec_op(packed[v]);
  }

  // 3. scalar access for the leftover elements that do not fill a vector.
  for (int e = packed_count * VEC_SIZE + tid; e < remaining; e += stride) {
    scalar_op(body[e]);
  }
}
// Convenience overload: the caller supplies only a per-element functor.
//
// A DefaultReadVecOp holding a copy of `scalar_op` is built and used as the
// vector functor, so vectorized regions are processed by applying that copy
// to each element of the vector; `scalar_op` itself is forwarded for the
// scalar regions. Because the vector path works on a copy, a functor that
// keeps state by value will not share that state between the two paths
// (functors that only capture by reference are unaffected).
//
// Parameters have the same meaning as in the two-functor overload:
//   in, len     - input range of `len` elements starting at `in`.
//   tid, stride - start index and step used by every traversal loop.
//   scalar_op   - invoked as scalar_op(const InT&).
template <int VEC_SIZE, typename InT, typename ScaOp>
__device__ __forceinline__ void vectorize_read_with_alignment(
    const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
  using ScalarFn = std::decay_t<ScaOp>;
  using VecAdapter = DefaultReadVecOp<VEC_SIZE, InT, ScalarFn>;

  VecAdapter per_vector{scalar_op};
  // Moved so the callee still deduces VecOp as a non-reference type.
  vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride,
                                          std::move(per_vector),
                                          std::forward<ScaOp>(scalar_op));
}
}
// namespace vllm
}
// namespace vllm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment