Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f3deca99
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "862f2ef893d9751db0a92bd2d4ae0e3d9677872f"
Commit
f3deca99
authored
Mar 26, 2025
by
gaoqiong
Browse files
增加blockint8支持优化
parent
5c241fa9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
314 additions
and
71 deletions
+314
-71
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+156
-17
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+82
-53
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+44
-1
vllm/utils.py
vllm/utils.py
+32
-0
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
f3deca99
...
@@ -21,6 +21,91 @@ from vllm.platforms import current_platform
...
@@ -21,6 +21,91 @@ from vllm.platforms import current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
if
device_name
==
'BW200'
or
device_name
==
'K100_AI'
:
stage1_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#11
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#12
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
]
stage2_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#12
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#16
]
else
:
stage1_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#12
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
]
stage2_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#12
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#16
]
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel_awq
(
def
fused_moe_kernel_awq
(
...
@@ -1516,23 +1601,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1516,23 +1601,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
if
not
use_int8_w8a8
:
use_int8_w8a8
=
use_int8_w8a8
,
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
dtype
=
hidden_states
.
dtype
)
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
get_config_func
=
functools
.
partial
(
w1
.
shape
,
try_get_optimal_moe_config
,
w2
.
shape
,
w1
.
shape
,
topk_ids
.
shape
[
1
],
w2
.
shape
,
config_dtype
,
topk_ids
.
shape
[
1
],
block_shape
=
block_shape
,
config_dtype
,
use_nn_moe
=
use_nn_moe
,
block_shape
=
block_shape
,
)
use_nn_moe
=
use_nn_moe
,
)
config
=
get_config_func
(
M
)
config
=
get_config_func
(
M
)
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
...
@@ -1584,6 +1670,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1584,6 +1670,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
config
=
stage1_best_config
[
m
-
1
]
elif
m
<=
32
:
config
=
stage1_best_config
[
15
]
elif
m
<=
64
:
config
=
stage1_best_config
[
16
]
elif
m
<
256
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
if
moe_ep_size
==
1
:
if
moe_ep_size
==
1
:
if
use_int4_w4a16
:
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
...
@@ -1620,7 +1733,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1620,7 +1733,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
intermediate_cache1
.
view
(
-
1
,
N
))
if
use_int8_w8a8
:
m1
=
intermediate_cache2
.
shape
[
0
]
if
m1
<=
16
:
config
=
stage2_best_config
[
m1
-
1
]
elif
m1
<=
32
:
config
=
stage2_best_config
[
15
]
elif
m1
<=
64
:
config
=
stage2_best_config
[
16
]
elif
m1
<
256
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
invoke_fused_moe_kernel
(
intermediate_cache2
,
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
...
...
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
f3deca99
...
@@ -9,12 +9,13 @@ from typing import Any, Dict, List, Optional, Tuple
...
@@ -9,12 +9,13 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.utils
import
W8a8GetCacheJSON
# from sglang.srt.utils import get_device_name
# from sglang.srt.utils import get_device_name
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
@
triton
.
jit
@
triton
.
jit
def
_per_token_quant_int8
(
def
_per_token_quant_int8
(
...
@@ -335,16 +336,16 @@ def w8a8_block_int8_matmul(
...
@@ -335,16 +336,16 @@ def w8a8_block_int8_matmul(
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
#configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
#
configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
#if configs:
#
if configs:
# # If an optimal configuration map has been found, look up the
# # If an optimal configuration map has been found, look up the
# # optimal config
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
#
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
#else:
#
else:
#
Default config
#
#Default config
#
Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#
#Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
#
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
# config = {
#
config = {
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "BLOCK_SIZE_K": block_size[1],
...
@@ -352,43 +353,78 @@ def w8a8_block_int8_matmul(
...
@@ -352,43 +353,78 @@ def w8a8_block_int8_matmul(
# "num_warps": 4,
# "num_warps": 4,
# "num_stages": 3,
# "num_stages": 3,
# }
# }
#print("W8A8_TRITONJSON.triton_json_dict[0]:",W8A8_TRITONJSON.triton_json_dict[0])
if
M
<=
64
:
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
config
=
{
config
=
None
"BLOCK_SIZE_M"
:
16
,
#64
#print("len(W8A8_TRITONJSON.triton_json_dict)=0:",len(W8A8_TRITONJSON.triton_json_dict))
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
elif
f
"1_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
"GROUP_SIZE_M"
:
2
,
if
M
<=
16
:
"num_warps"
:
4
,
m_
=
M
"num_stages"
:
0
,
elif
M
<=
64
:
}
m_
=
(
M
+
3
)
&
-
4
#取值到最近的4的倍数
elif
M
<
128
:
elif
M
<=
160
:
config
=
{
m_
=
(
M
+
7
)
&
-
8
"BLOCK_SIZE_M"
:
32
,
#64
"BLOCK_SIZE_N"
:
64
,
elif
M
<
200
:
#256
"BLOCK_SIZE_K"
:
128
,
m_
=
160
"GROUP_SIZE_M"
:
2
,
elif
M
<
480
:
#512
"num_warps"
:
4
,
m_
=
256
"num_stages"
:
0
,
elif
M
<
960
:
#1024
}
m_
=
512
elif
M
<=
256
:
elif
M
<
2048
:
config
=
{
m_
=
1024
"BLOCK_SIZE_M"
:
64
,
#64
elif
M
<
4096
:
"BLOCK_SIZE_N"
:
64
,
m_
=
2048
"BLOCK_SIZE_K"
:
128
,
elif
M
<
6000
:
"GROUP_SIZE_M"
:
2
,
m_
=
4096
"num_warps"
:
4
,
else
:
"num_stages"
:
0
,
m_
=
8192
}
#print("==================m:{},n:{},k:{}".format(M,N,K))
else
:
config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m_
}
_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
]
config
=
{
"BLOCK_SIZE_M"
:
64
,
#64
else
:
"BLOCK_SIZE_N"
:
128
,
config
=
None
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
# print("m:{},n:{},k:{}".format(M,N,K))
"num_warps"
:
8
,
# print("config not found!")
"num_stages"
:
0
,
}
if
M
<=
64
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
#64
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
0
,
}
elif
M
<
128
:
config
=
{
"BLOCK_SIZE_M"
:
32
,
#64
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
0
,
}
elif
M
<=
256
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
#64
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
0
,
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
#64
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
0
,
}
def
grid
(
META
):
def
grid
(
META
):
return
(
return
(
...
@@ -475,8 +511,6 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f
...
@@ -475,8 +511,6 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f
C
=
C
.
reshape
(
origin_C_shape
).
to
(
output_dtype
)
C
=
C
.
reshape
(
origin_C_shape
).
to
(
output_dtype
)
return
C
return
C
def
apply_w8a8_block_int8_linear
(
def
apply_w8a8_block_int8_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -497,11 +531,6 @@ def apply_w8a8_block_int8_linear(
...
@@ -497,11 +531,6 @@ def apply_w8a8_block_int8_linear(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
output_dtype
=
input
.
dtype
output_dtype
=
input
.
dtype
)
)
# output = native_w8a8_block_int8_matmul(
# q_input, weight, x_scale, weight_scale, block_size,
# output_dtype=input.dtype
# )
if
bias
is
not
None
:
if
bias
is
not
None
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
f3deca99
...
@@ -53,6 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -53,6 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
W8a8GetCacheJSON
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
...
@@ -677,6 +678,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -677,6 +678,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
@@ -948,7 +950,48 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -948,7 +950,48 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
scales
.
data
=
sz_tensor
scales
.
data
=
sz_tensor
if
hasattr
(
self
.
config
,
"quantization_config"
)
and
self
.
config
.
quantization_config
[
"quant_method"
]
==
"blockwise_int8"
:
lay_key_words
=
[
"self_attn.q_a_proj.weight"
,
"self_attn.q_b_proj.weight"
,
"self_attn.kv_b_proj.weight"
,
"self_attn.kv_a_proj_with_mqa.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
"mlp.shared_experts.gate_up_proj.weight"
,
"mlp.shared_experts.down_proj.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
n
=
weight_data
.
shape
[
0
]
if
len
(
matched_key_words
)
<
9
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
#print("n:{},k:{}".format(n,k))
json_file
=
self
.
tritonsingleton
.
get_blockint8json_name
(
n
,
k
,
128
,
128
)
configs_dict
=
self
.
tritonsingleton
.
get_blockint8_triton_cache
(
json_file
,
n
,
k
,
128
,
128
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
# ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return
loaded_params
return
loaded_params
...
...
vllm/utils.py
View file @
f3deca99
...
@@ -1611,6 +1611,38 @@ class W8a8GetCacheJSON:
...
@@ -1611,6 +1611,38 @@ class W8a8GetCacheJSON:
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
return
self
.
triton_json_dir
+
f
"/W8A8_
{
n
}
_
{
k
}
_
{
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/W8A8_
{
n
}
_
{
k
}
_
{
device_name
}
.json"
def
get_blockint8_triton_cache
(
self
,
file_path
,
n
,
k
,
block_n
,
block_k
):
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
#try:
with
open
(
cache_json_file
,
'r'
)
as
file
:
cachedata
=
json
.
load
(
file
)
else
:
return
None
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict
=
{}
for
key
,
value
in
cachedata
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_value
=
{
'BLOCK_SIZE_M'
:
int
(
sub_value
[
"BLOCK_SIZE_M"
]),
'BLOCK_SIZE_N'
:
int
(
sub_value
[
"BLOCK_SIZE_N"
]),
'BLOCK_SIZE_K'
:
int
(
sub_value
[
"BLOCK_SIZE_K"
]),
'GROUP_SIZE_M'
:
int
(
sub_value
[
"GROUP_SIZE_M"
]),
'kpack'
:
int
(
sub_value
[
"kpack"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
]),
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
def
get_blockint8json_name
(
self
,
n
,
k
,
block_n
,
block_k
):
from
vllm.platforms
import
current_platform
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
device_name
}
.json"
# Adapted from: https://stackoverflow.com/a/47212782/5082708
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class
LazyDict
(
Mapping
[
str
,
T
],
Generic
[
T
]):
class
LazyDict
(
Mapping
[
str
,
T
],
Generic
[
T
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment