Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
333104ab
Commit
333104ab
authored
Jul 31, 2025
by
jujl1
Browse files
feat:新增VLLM_USE_GLOBAL_CACHE13 设置moe使用全局变量的cache13
parent
e92bb9ea
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
20 deletions
+35
-20
vllm/envs.py
vllm/envs.py
+6
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+29
-19
No files found.
vllm/envs.py
View file @
333104ab
...
...
@@ -163,7 +163,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
VLLM_USE_APEX_RN
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1085,6 +1085,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_APEX_RN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_APEX_RN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
333104ab
...
...
@@ -44,6 +44,14 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger
=
init_logger
(
__name__
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
moe_cache_singleton
=
None
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
global
moe_cache_singleton
if
moe_cache_singleton
is
None
:
moe_cache_singleton
=
torch
.
empty
(
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
*
top_k_num
*
max
(
N
,
K
),
device
=
device
,
dtype
=
dtype
)
logger
.
info
(
f
"Initializing moe_cache_singleton shape:
{
moe_cache_singleton
.
shape
}
, memory:
{
moe_cache_singleton
.
element_size
()
*
moe_cache_singleton
.
numel
()
/
1024
**
2
:.
2
f
}
MB"
)
return
moe_cache_singleton
@
triton
.
jit
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
...
...
@@ -1494,13 +1502,32 @@ def fused_experts_impl(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
# Check constraints.
num_tokens
=
hidden_states
.
size
(
0
)
if
use_nn_moe
:
E
,
_
,
N
=
w1
.
size
()
else
:
E
,
N
,
_
=
w1
.
size
()
K
=
w2
.
size
(
1
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
size
(
1
)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
cache13
=
get_moe_cache
(
top_k_num
,
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
],
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
else
:
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
is
True
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
cache13
=
cache13
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
...
...
@@ -1527,6 +1554,7 @@ def fused_experts_impl(
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
cache13
=
cache13
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
...
...
@@ -1565,21 +1593,6 @@ def fused_experts_impl(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
num_tokens
=
hidden_states
.
size
(
0
)
if
use_nn_moe
:
E
,
_
,
N
=
w1
.
size
()
else
:
E
,
N
,
_
=
w1
.
size
()
K
=
w2
.
size
(
1
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
size
(
1
)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
...
...
@@ -1606,9 +1619,6 @@ def fused_experts_impl(
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
M
,
top_k_num
,
N
)
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
(
K
if
not
use_nn_moe
else
w2
.
shape
[
2
])].
view
(
M
,
top_k_num
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment