Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
12a59959
Unverified
Commit
12a59959
authored
Jul 02, 2024
by
Avshalom Manevich
Committed by
GitHub
Jul 01, 2024
Browse files
[Bugfix] adding chunking mechanism to fused_moe to handle large inputs (#6029)
parent
dec6fc6f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
48 deletions
+74
-48
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+1
-1
vllm/envs.py
vllm/envs.py
+3
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+70
-47
No files found.
tests/kernels/test_moe.py
View file @
12a59959
...
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
...
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
...
...
vllm/envs.py
View file @
12a59959
...
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
...
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
@@ -248,6 +249,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -248,6 +249,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Only used for XLA devices such as TPUs.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"65536"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
12a59959
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
]
M
,
_
=
hidden_states
.
shape
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
if
M
>
65536
:
# https://github.com/vllm-project/vllm/issues/5938
# https://github.com/vllm-project/vllm/issues/5938
raise
ValueError
(
"MoE kernel does not support more than 65536 tokens, "
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
f
"but got
{
M
}
"
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
override_config
:
if
override_config
:
config
=
override_config
config
=
override_config
...
@@ -455,18 +455,43 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -455,18 +455,43 @@ def fused_experts(hidden_states: torch.Tensor,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
compute_type
=
(
tl
.
bfloat16
compute_type
=
(
tl
.
bfloat16
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
invoke_fused_moe_kernel
(
hidden_states
,
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
if
tokens_in_chunk
<
CHUNK_SIZE
:
# will only happen in the last chunk
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
a1_scale
,
a1_scale
,
w1_scale
,
w1_scale
,
topk_weights
,
curr_
topk_weights
,
topk_ids
,
curr_
topk_ids
,
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
...
@@ -483,8 +508,8 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -483,8 +508,8 @@ def fused_experts(hidden_states: torch.Tensor,
intermediate_cache3
,
intermediate_cache3
,
a2_scale
,
a2_scale
,
w2_scale
,
w2_scale
,
topk_weights
,
curr_
topk_weights
,
topk_ids
,
curr_
topk_ids
,
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
...
@@ -494,12 +519,10 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -494,12 +519,10 @@ def fused_experts(hidden_states: torch.Tensor,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8
=
use_fp8
)
if
inplace
:
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
,
dim
=
1
,
out
=
hidden_states
)
out
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
return
out_hidden_states
dim
=
1
)
def
fused_moe
(
def
fused_moe
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment