Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7a963f5
Commit
e7a963f5
authored
Jun 07, 2025
by
yangql
Browse files
新增fusemoe手写算子的支持,需要group-gemm包
parent
6880bf15
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
335 additions
and
1 deletion
+335
-1
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+26
-1
vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py
...odel_executor/layers/quantization/utils/fused_moe_cuda.py
+309
-0
No files found.
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
e7a963f5
...
...
@@ -15,7 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.model_executor.layers.fused_moe
import
fused_experts
os
.
environ
[
'W4A16_MOE_CUDA'
]
=
os
.
environ
.
get
(
'W4A16_MOE_CUDA'
,
'0'
)
if
os
.
environ
[
'W4A16_MOE_CUDA'
]
==
'1'
:
from
vllm.model_executor.layers.quantization.utils.fused_moe_cuda
import
fused_experts_cuda
class
MoeWNA16Config
(
QuantizationConfig
):
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
...
...
@@ -176,6 +180,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
MoeWNA16Config
):
self
.
quant_config
=
quant_config
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
use_w4a16_cuda
=
os
.
environ
[
'W4A16_MOE_CUDA'
]
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
...
...
@@ -329,7 +334,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -347,6 +352,26 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits
=
self
.
quant_config
.
weight_bits
has_zp
=
self
.
quant_config
.
has_zp
if
self
.
use_w4a16_cuda
:
m
=
topk_ids
.
shape
[
0
]
if
m
<=
64
:
return
fused_experts_cuda
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
,
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
False
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
False
,
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
w1_zp
=
None
,
w2_zp
=
None
,
a1_scale
=
None
,
a2_scale
=
None
,
block_shape
=
[
0
,
layer
.
group_size
],
expert_map
=
expert_map
)
return
fused_experts
(
x
,
...
...
vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py
0 → 100644
View file @
e7a963f5
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import
functools
import
json
import
os
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
triton
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
grouped_gemm
import
moe_gemm_w4a16
from
grouped_gemm.ops
import
permute
as
permute_topK
,
unpermute
as
unpermute_topK
import
torch.nn.functional
as
F
logger
=
init_logger
(
__name__
)
# Kernel "mode" ids keyed by token count M.  The values are opaque
# selectors that fused_experts_impl_cuda forwards as the ``mode`` argument
# of moe_gemm_w4a16.gemm1_w4a16 / gemm2_w4a16.  They are constant, so they
# are built once at import time instead of on every config_cuda() call.
_BW_GEMM1_MODES = {
    1: 83, 2: 77, 3: 32, 4: 38, 5: 38, 6: 87, 7: 82, 8: 42,
    9: 83, 10: 42, 11: 42, 12: 87, 13: 42, 14: 38, 15: 42, 16: 42,
    17: 42, 18: 87, 19: 87, 20: 83, 21: 83, 22: 83, 23: 83, 24: 27,
    25: 42, 26: 83, 27: 38, 28: 42, 29: 42, 30: 38, 31: 42, 32: 38,
}
_BW_GEMM2_MODES = {
    1: 23, 2: 88, 3: 74, 4: 39, 5: 43, 6: 88, 7: 88, 8: 89,
    9: 73, 10: 88, 11: 88, 12: 88, 13: 88, 14: 88, 15: 88, 16: 88,
    17: 88, 18: 43, 19: 88, 20: 43, 21: 43, 22: 43, 23: 88, 24: 88,
    25: 88, 26: 88, 27: 88, 28: 88, 29: 88, 30: 43, 31: 88, 32: 88,
}
# The gemm1 table for non-"BW" devices is intentionally sparse in the
# original; missing token counts fall back to the entry for 32.
_K100AI_GEMM1_MODES = {
    1: 79, 2: 34, 3: 34, 4: 34, 6: 34, 8: 34, 16: 34, 24: 34, 32: 34,
}
_K100AI_GEMM2_MODES = {
    1: 64, 2: 33, 3: 33, 4: 37, 5: 37, 6: 33, 7: 33, 8: 37,
    9: 37, 10: 37, 11: 37, 12: 37, 13: 37, 14: 38, 15: 38, 16: 72,
    17: 72, 18: 72, 19: 72, 20: 72, 21: 72, 22: 72, 23: 72, 24: 39,
    25: 39, 26: 39, 27: 39, 28: 39, 29: 39, 30: 39, 31: 39, 32: 39,
}


def config_cuda(M):
    """Return the ``(gemm1_mode, gemm2_mode)`` pair to use for ``M`` tokens.

    The table set is chosen from the current device name: names containing
    ``"BW"`` use the ``bw`` tables, every other device uses the ``k100ai``
    tables.

    Args:
        M: Token count used as the lookup key.

    Returns:
        A ``(mode_1, mode_2)`` tuple of ints.  When ``M`` has no entry in a
        table, that table's entry for ``32`` is returned instead.
    """
    device_name = current_platform.get_device_name()
    if "BW" in device_name:
        gemm1_mode_dict = _BW_GEMM1_MODES
        gemm2_mode_dict = _BW_GEMM2_MODES
    else:
        gemm1_mode_dict = _K100AI_GEMM1_MODES
        gemm2_mode_dict = _K100AI_GEMM2_MODES

    # Unknown token counts share the configuration stored under key 32.
    mode_1 = gemm1_mode_dict.get(M, gemm1_mode_dict[32])
    mode_2 = gemm2_mode_dict.get(M, gemm2_mode_dict[32])
    return mode_1, mode_2
def fused_experts_cuda(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    expert_map: Optional[torch.Tensor] = None,
):
    """Dispatch to ``fused_experts_impl_cuda`` in in-place or copy mode.

    All arguments are forwarded positionally and unchanged, except that
    ``inplace`` is normalised to a literal ``True``/``False``.

    Returns:
        ``hidden_states`` itself when ``inplace`` is truthy, otherwise
        whatever ``fused_experts_impl_cuda`` returns.
    """
    # Arguments that follow ``inplace`` in the implementation's signature.
    trailing_args = (
        use_fp8_w8a8,
        use_int8_w8a16,
        use_int4_w4a16,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        expert_map,
    )

    if not inplace:
        return fused_experts_impl_cuda(hidden_states, w1, w2, topk_weights,
                                       topk_ids, False, *trailing_args)

    # In-place mode: the implementation's return value is discarded and the
    # caller's tensor is handed back.
    fused_experts_impl_cuda(hidden_states, w1, w2, topk_weights, topk_ids,
                            True, *trailing_args)
    return hidden_states
def fused_experts_impl_cuda(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    expert_map: Optional[torch.Tensor] = None,
):
    """Run the MoE expert computation with the ``moe_gemm_w4a16`` kernels.

    Tokens are processed in chunks of at most ``CHUNK_SIZE`` rows.  For each
    chunk the routed tokens are aligned with ``moe_align_block_size``, then
    ``gemm1_w4a16`` -> ``silu_and_mul`` -> ``gemm2_w4a16`` -> ``ops.moe_sum``
    are executed, the last one writing into the output rows of that chunk.

    Args:
        hidden_states: 2-D contiguous float32/float16/bfloat16 tensor.
        w1: 3-D contiguous tensor; ``w1.shape[2]`` must equal
            ``hidden_states.shape[1] // 2``.
        w2: Contiguous tensor; ``w2.shape[1]`` sizes the gemm2 output.
        topk_weights: Same shape as ``topk_ids``; forwarded to gemm2.
        topk_ids: 2-D tensor; ``topk_ids.shape[1]`` is passed as ``topk``.
        inplace: When true the result is written into ``hidden_states``.
        w1_scale: Forwarded to gemm1 (kernel argument ``scale_zero``).
        w2_scale: Forwarded to gemm2 (kernel argument ``scale_zero``).
        expert_map: Forwarded to ``moe_align_block_size``.
        use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape: Accepted for signature
            compatibility; this implementation does not read them.

    Returns:
        ``hidden_states`` when ``inplace`` is true, otherwise a new tensor
        created with ``torch.empty_like(hidden_states)``.

    Raises:
        AssertionError: A shape, contiguity or dtype constraint is violated.
        ValueError: Unsupported dtype when assertions are disabled.
    """
    # Check constraints.
    assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    # Explicit guard so unsupported dtypes are still rejected under
    # ``python -O``.  It replaces the former if/elif chain whose only other
    # effect was to set an unused ``compute_type`` local.
    if hidden_states.dtype not in (torch.bfloat16, torch.float16,
                                   torch.float32):
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    top_k = topk_ids.shape[1]

    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = 32768
    # Block size given to moe_align_block_size; also the divisor used to
    # trim expert_ids below.
    BLOCK_SIZE_M = 16
    # NOTE(review): the group size handed to both GEMM kernels is fixed and
    # ``block_shape`` is not consulted -- confirm the kernels only support
    # this value before wiring ``block_shape[1]`` through.
    GROUP_SIZE = 64
    M = min(num_tokens, CHUNK_SIZE)

    intermediate_cache1 = torch.empty((M, top_k, N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache2 = torch.empty((M * top_k, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = torch.empty((M, top_k, w2.shape[1]),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx = chunk * CHUNK_SIZE
        end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, num_tokens)
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Shrink the intermediate caches for the last, partial chunk.
            # In most cases there is only one chunk and the caches already
            # have the right size.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
                                                      top_k]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = (
            moe_align_block_size(curr_topk_ids, BLOCK_SIZE_M, E, expert_map,
                                 curr_hidden_states.shape[0]))

        # NOTE(review): the mode lookup is keyed by M (size of the first
        # chunk), not by tokens_in_chunk -- kept as in the original.
        mode_1, mode_2 = config_cuda(M)
        expert_ids = expert_ids[:num_tokens_post_padded // BLOCK_SIZE_M]

        # Both GEMMs take the sorted ids as uint16; convert once per chunk.
        # NOTE(review): uint16 cannot represent ids above 65535 --
        # presumably callers keep the batch small enough; TODO confirm.
        sorted_token_ids_u16 = sorted_token_ids.to(torch.uint16)

        moe_gemm_w4a16.gemm1_w4a16(
            sorted_token_ids_u16,  # sorted_token_ids
            curr_hidden_states,  # hidden_states
            w1,  # w1
            intermediate_cache1,  # gemm1_out
            expert_ids,  # expert_id_vec
            w1_scale,  # scale_zero
            GROUP_SIZE,  # group_size
            topk=top_k,
            mode=mode_1)

        torch.ops._C.silu_and_mul(intermediate_cache2,
                                  intermediate_cache1.view(-1, N))

        moe_gemm_w4a16.gemm2_w4a16(
            sorted_token_ids_u16,  # sorted_token_ids
            intermediate_cache2,  # hidden_states
            w2,  # w2
            intermediate_cache3,  # gemm2_out
            expert_ids,  # expert_id_vec
            w2_scale,  # scale_zero
            curr_topk_weights,  # topk_weights
            GROUP_SIZE,  # group_size
            topk=top_k,
            mode=mode_2)

        ops.moe_sum(intermediate_cache3,
                    out_hidden_states[begin_chunk_idx:end_chunk_idx])

    return out_hidden_states
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment