Commit 2bd4a707 authored by zhuwenwen
Browse files

add VLLM_USE_PD_SPLIT to split prefill and decode

parent 0eaf8026
...@@ -176,6 +176,7 @@ if TYPE_CHECKING: ...@@ -176,6 +176,7 @@ if TYPE_CHECKING:
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_OUTPUT_PLACEHOLDERS: bool = False VLLM_ENABLE_OUTPUT_PLACEHOLDERS: bool = False
VLLM_USE_PD_SPLIT: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1092,7 +1093,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1092,7 +1093,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use apex for rmsnorm # vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
...@@ -1146,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1146,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will enable output placeholders # vllm will enable output placeholders
"VLLM_ENABLE_OUTPUT_PLACEHOLDERS": "VLLM_ENABLE_OUTPUT_PLACEHOLDERS":
lambda: bool(int(os.getenv("VLLM_ENABLE_OUTPUT_PLACEHOLDERS", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_OUTPUT_PLACEHOLDERS", "0"))),
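# If enabled, vLLM schedules prefill and decode separately instead of mixing them.
# NOTE(review): the runtime default below is "True", while the TYPE_CHECKING stub declares False — confirm which is intended.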
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1014,7 +1014,7 @@ class Scheduler(SchedulerInterface): ...@@ -1014,7 +1014,7 @@ class Scheduler(SchedulerInterface):
return scheduler_output return scheduler_output
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
if self.num_spec_tokens > 0: if self.num_spec_tokens > 0 or envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd() return self.schedule_split_pd()
else: else:
return self.schedule_default() return self.schedule_default()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment