Commit cff5c2d2 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_LIGHTOP_FILL_MOE_ALIN

parent ed53dfb0
...@@ -179,6 +179,7 @@ if TYPE_CHECKING: ...@@ -179,6 +179,7 @@ if TYPE_CHECKING:
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_LIGHTOP_FILL_MOE_ALIN: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1161,6 +1162,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1161,6 +1162,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_PP_SYNC": "VLLM_USE_PP_SYNC":
lambda: (os.environ.get("VLLM_USE_PP_SYNC", "False").lower() in lambda: (os.environ.get("VLLM_USE_PP_SYNC", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop to fuse fill and moe align
"VLLM_USE_LIGHTOP_FILL_MOE_ALIN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIN", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -216,7 +216,9 @@ def moe_align_block_size( ...@@ -216,7 +216,9 @@ def moe_align_block_size(
sorted_ids = torch.empty((max_num_tokens_padded, ), sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if not envs.VLLM_USE_LIGHTOP_FILL_MOE_ALIN:
sorted_ids.fill_(topk_ids.numel()) sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while # Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism. # mapping global expert ids to local expert ids in expert parallelism.
......
...@@ -251,6 +251,8 @@ def get_model_architecture( ...@@ -251,6 +251,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIN'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment