Commit 865dc64d authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]新增VLLM_USE_GLOBAL_CACHE13 设置moe使用全局变量的cache13

parent 0f12f80a
......@@ -168,6 +168,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
def get_default_cache_root():
......@@ -1116,6 +1117,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "True").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -51,6 +51,16 @@ from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton
if moe_cache_singleton is None:
moe_cache_singleton = torch.empty(envs.VLLM_FUSED_MOE_CHUNK_SIZE * top_k_num *max(N, K), device=device, dtype=dtype)
logger.info(f"Initializing moe_cache_singleton shape: {moe_cache_singleton.shape}, memory: {moe_cache_singleton.element_size() * moe_cache_singleton.numel() / 1024**2:.2f} MB")
return moe_cache_singleton
@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
......@@ -1618,12 +1628,33 @@ def fused_experts_impl(
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
# Check constraints.
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......@@ -1650,6 +1681,7 @@ def fused_experts_impl(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
cache13=cache13,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
......@@ -1690,21 +1722,6 @@ def fused_experts_impl(
torch.float32, torch.float16, torch.bfloat16
]
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
......@@ -1733,9 +1750,6 @@ def fused_experts_impl(
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment