Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f52afe3f
Commit
f52afe3f
authored
May 29, 2025
by
gaoqiong
Browse files
增加w8a8 线性gemm config优化
parent
7f301a2c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
33 deletions
+26
-33
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+25
-33
vllm/utils.py
vllm/utils.py
+1
-0
No files found.
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
f52afe3f
...
...
@@ -300,8 +300,8 @@ def _w8a8_block_int8_matmul(
GROUP_SIZE_M
:
tl
.
constexpr
,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid
=
tl
.
program_id
(
axis
=
0
)
...
...
@@ -316,16 +316,29 @@ def _w8a8_block_int8_matmul(
offs_am
=
(
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
))
%
M
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn
=
pid_n
*
BLOCK_SIZE_N
//
group_n
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
A
+
(
offs_am
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
B
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
# a_ptrs = A + (offs_am[:, None] * stride_am)
# b_ptrs = B + (offs_bn[None, :] * stride_bn)
As_ptrs
=
As
+
offs_am
*
stride_As_m
offs_bsn
=
offs_bn
//
group_n
# offs_bsn = offs_bn // group_n
Bs_ptrs
=
Bs
+
offs_bsn
*
stride_Bs_n
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
...
...
@@ -333,16 +346,13 @@ def _w8a8_block_int8_matmul(
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
accumulator
+=
tl
.
dot
(
a
,
b
).
to
(
tl
.
float32
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
accumulator
+=
tl
.
dot
(
a
,
b
).
to
(
tl
.
float32
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
C
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
C
.
dtype
.
element_ty
==
tl
.
float16
:
...
...
@@ -436,27 +446,8 @@ def w8a8_block_int8_matmul(
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
# configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
# if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# else:
# # Default config
# # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
# config = {
# "BLOCK_SIZE_M": 64,
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
#print("W8A8_TRITONJSON.triton_json_list[0]:",W8A8_TRITONJSON.triton_json_list[0])
if
len
(
W8A8_TRITONJSON
.
triton_json_list
)
==
0
:
config
=
None
#print("len(W8A8_TRITONJSON.triton_json_list)=0:",len(W8A8_TRITONJSON.triton_json_list)) triton_json
elif
f
"1_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
in
W8A8_TRITONJSON
.
triton_json_list
[
0
]:
if
M
<=
16
:
...
...
@@ -480,12 +471,13 @@ def w8a8_block_int8_matmul(
m_
=
4096
else
:
m_
=
8192
#print("==================m:{},n:{},k:{}".format(M,N,K))
config
=
W8A8_TRITONJSON
.
triton_json_list
[
0
][
f
"
{
m_
}
_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
]
else
:
config
=
None
if
config
==
None
:
# print("m:{},n:{},k:{}".format(M,N,K))
# print("config not found!")
...
...
vllm/utils.py
View file @
f52afe3f
...
...
@@ -1825,6 +1825,7 @@ class W8a8GetCacheJSON:
'kpack'
:
int
(
sub_value
[
"kpack"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
]),
'enable_mmacfuse'
:
int
(
sub_value
[
'enable_mmacfuse'
]),
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment