Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
268d8a77
Commit
268d8a77
authored
Mar 14, 2025
by
gaoqiong
Browse files
增加线性int8 gemm配置
parent
92545504
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
13 deletions
+49
-13
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+49
-13
No files found.
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
268d8a77
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/3730
import
functools
import
json
import
logging
...
...
@@ -336,23 +335,60 @@ def w8a8_block_int8_matmul(
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
configs
=
get_w8a8_block_int8_configs
(
N
,
K
,
block_size
[
0
],
block_size
[
1
])
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
#
configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
#
if configs:
#
# If an optimal configuration map has been found, look up the
#
# optimal config
#
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
#
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
# config = {
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
if
M
<=
64
:
config
=
{
"BLOCK_SIZE_M"
:
32
,
#64
"BLOCK_SIZE_N"
:
block_size
[
0
]
,
"BLOCK_SIZE_K"
:
block_size
[
1
]
,
"GROUP_SIZE_M"
:
3
2
,
"BLOCK_SIZE_M"
:
16
,
#64
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_stages"
:
0
,
}
elif
M
<
128
:
config
=
{
"BLOCK_SIZE_M"
:
32
,
#64
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
0
,
}
elif
M
<=
256
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
#64
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
0
,
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
#64
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
0
,
}
def
grid
(
META
):
return
(
...
...
@@ -514,4 +550,4 @@ def block_dequant(
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
k
),
]
*=
x_s
[
j
][
i
]
return
x_dq_block
\ No newline at end of file
return
x_dq_block
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment