Commit d8ea775f authored by jujl1's avatar jujl1
Browse files

feat: pipeline_parallel新增pp域请求数均衡,VLLM_USE_PP_BALANCE控制,默认开启

parent 21d22cbd
...@@ -181,6 +181,7 @@ if TYPE_CHECKING: ...@@ -181,6 +181,7 @@ if TYPE_CHECKING:
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT: bool = False USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT: bool = False
VLLM_USE_PP_BALANCE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1171,6 +1172,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1171,6 +1172,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT": "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
"VLLM_USE_PP_BALANCE":
lambda: (os.getenv('VLLM_USE_PP_BALANCE', '1').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -213,9 +213,15 @@ class Scheduler(SchedulerInterface): ...@@ -213,9 +213,15 @@ class Scheduler(SchedulerInterface):
# First, schedule the RUNNING requests. # First, schedule the RUNNING requests.
req_index = 0 req_index = 0
if envs.VLLM_USE_PP_BALANCE and self.use_pp:
pipeline_size = self.parallel_config.pipeline_parallel_size
max_batch_running = (len(self.waiting) + len(self.running) + pipeline_size - 1 ) // pipeline_size
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
if (envs.VLLM_USE_PP_BALANCE and self.use_pp and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
num_new_tokens = (request.num_tokens_with_spec - num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens) request.num_computed_tokens)
...@@ -359,6 +365,11 @@ class Scheduler(SchedulerInterface): ...@@ -359,6 +365,11 @@ class Scheduler(SchedulerInterface):
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
if (envs.VLLM_USE_PP_BALANCE and self.use_pp and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if request.is_finished(): if request.is_finished():
self.waiting.pop_request() self.waiting.pop_request()
...@@ -643,11 +654,18 @@ class Scheduler(SchedulerInterface): ...@@ -643,11 +654,18 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests = create_request_queue(self.policy) skipped_waiting_requests = create_request_queue(self.policy)
req_index = len(self.running) req_index = len(self.running)
if envs.VLLM_USE_PP_BALANCE and self.use_pp:
pipeline_size = self.parallel_config.pipeline_parallel_size
max_batch_running = (len(self.waiting) + len(self.running) + pipeline_size - 1 ) // pipeline_size
# First, schedule the WAITING requests. # First, schedule the WAITING requests.
while self.waiting and token_budget > 0: while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
#TODO:考虑到decode过程中来新请求时,可以一次性处理所有请求的prefill 也许schedule the WAITING requests 中取消pp平衡效果更好
if (envs.VLLM_USE_PP_BALANCE and self.use_pp and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.waiting.peek_request() request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
...@@ -828,6 +846,10 @@ class Scheduler(SchedulerInterface): ...@@ -828,6 +846,10 @@ class Scheduler(SchedulerInterface):
if not scheduled_new_reqs and not scheduled_resumed_reqs: if not scheduled_new_reqs and not scheduled_resumed_reqs:
req_index = 0 req_index = 0
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
if (envs.VLLM_USE_PP_BALANCE and self.use_pp and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.running[req_index] request = self.running[req_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment