Commit 333104ab authored by jujl1's avatar jujl1
Browse files

feat:新增VLLM_USE_GLOBAL_CACHE13 设置moe使用全局变量的cache13

parent e92bb9ea
...@@ -163,7 +163,7 @@ if TYPE_CHECKING: ...@@ -163,7 +163,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1085,6 +1085,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1085,6 +1085,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -44,6 +44,14 @@ from vllm.utils import direct_register_custom_op ...@@ -44,6 +44,14 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton
if moe_cache_singleton is None:
moe_cache_singleton = torch.empty(envs.VLLM_FUSED_MOE_CHUNK_SIZE * top_k_num *max(N, K), device=device, dtype=dtype)
logger.info(f"Initializing moe_cache_singleton shape: {moe_cache_singleton.shape}, memory: {moe_cache_singleton.element_size() * moe_cache_singleton.numel() / 1024**2:.2f} MB")
return moe_cache_singleton
@triton.jit @triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
...@@ -1494,13 +1502,32 @@ def fused_experts_impl( ...@@ -1494,13 +1502,32 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
if use_int8_w8a8 is True: if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states, return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
cache13 = cache13,
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
...@@ -1527,6 +1554,7 @@ def fused_experts_impl( ...@@ -1527,6 +1554,7 @@ def fused_experts_impl(
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
cache13 = cache13,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8= False, use_fp8_w8a8= False,
...@@ -1565,21 +1593,6 @@ def fused_experts_impl( ...@@ -1565,21 +1593,6 @@ def fused_experts_impl(
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
] ]
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
...@@ -1606,9 +1619,6 @@ def fused_experts_impl( ...@@ -1606,9 +1619,6 @@ def fused_experts_impl(
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2]) intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment