Commit 76ec56bd authored by jujl1's avatar jujl1
Browse files

feat: pp balance

parent b8f555af
......@@ -291,6 +291,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_W8A8_BACKEND: int = 3
# Declared default for the pipeline-parallel balance toggle; annotated as
# bool so it matches the neighbouring declarations.
VLLM_USE_PP_BALANCE: bool = True
VLLM_REJECT_SAMPLE_OPT: bool = False
......@@ -1790,8 +1791,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
......@@ -1831,6 +1832,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")),
# When pipeline parallelism is in use, the scheduler caps the number of
# requests scheduled per step at
# ceil((waiting + running) / pipeline_parallel_size).
# Enabled unless the variable is set to something other than "true"/"1"
# (case-insensitive); presumably meant to even out batch sizes across
# pipeline stages - TODO confirm.
"VLLM_USE_PP_BALANCE":
lambda: (os.environ.get("VLLM_USE_PP_BALANCE", "True").lower() in
("true", "1")),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
......
......@@ -343,7 +343,10 @@ class Scheduler(SchedulerInterface):
# For logging.
scheduled_timestamp = time.monotonic()
# With pipeline parallelism and VLLM_USE_PP_BALANCE both enabled, compute a
# per-step cap on the number of scheduled requests: the total of waiting and
# running requests divided by the pipeline size, rounded up.
# NOTE(review): max_batch_running is only bound inside this branch, so every
# later use must be guarded by the same condition.
if self.use_pp and envs.VLLM_USE_PP_BALANCE:
pipeline_size = self.parallel_config.pipeline_parallel_size
# Adding (pipeline_size - 1) before the floor division rounds the quotient up.
max_batch_running = (len(self.waiting) + len(self.running)
+ pipeline_size - 1 ) // pipeline_size
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
......@@ -352,9 +355,14 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if self.use_pp and request.num_output_placeholders > 0:
req_index += 1
continue
if self.use_pp:
if (envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
if request.num_output_placeholders > 0:
req_index += 1
continue
if (
request.num_output_placeholders > 0
......@@ -543,7 +551,10 @@ class Scheduler(SchedulerInterface):
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
if (self.use_pp and envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment