"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a91a273d0b6c1b47be5005e4cf163adcb36d1176"
Unverified commit 12cad0fe authored by Lianmin Zheng, committed by GitHub

Simplify the interface of tp_worker (#1718)

parent b6cd9036
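
python/sglang/srt/managers/scheduler.py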
@@ -91,6 +91,7 @@ class Scheduler:
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
@@ -144,13 +145,24 @@ class Scheduler:
         # Launch a tensor parallel worker
         self.tp_worker = TpModelWorker(
+            server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
-            server_args=server_args,
+            dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.device = self.tp_worker.device
+
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
 
         # Get token and memory info from the model worker
         (
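
This hunk moves the overlap-schedule setup next to worker construction: the scheduler binds either the blocking or the non-blocking forward path once, so the event loop never re-checks enable_overlap_schedule. A minimal sketch of that bind-once dispatch pattern, using hypothetical names rather than the real SGLang classes:

    class Worker:
        def forward_blocking(self, batch):
            # compute and return real next-token ids
            return [x + 1 for x in batch]

        def forward_non_blocking(self, batch):
            # enqueue work and return placeholder ids to resolve later
            return [-1 for _ in batch]

    class Sched:
        def __init__(self, worker, overlap):
            # Bind the strategy once at construction time; callers just
            # invoke self.forward_batch_generation with no flag checks.
            self.forward_batch_generation = (
                worker.forward_non_blocking if overlap else worker.forward_blocking
            )

    print(Sched(Worker(), overlap=False).forward_batch_generation([1, 2]))  # [2, 3]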
@@ -159,11 +171,11 @@ class Scheduler:
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         ) = self.tp_worker.get_token_and_memory_info()
+        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )
 
         # Print debug info
         logger.info(
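
The pad_input_ids_func lookup now lives behind TpModelWorker.get_pad_input_ids_func(), which keeps the getattr(model, "pad_input_ids", None) fallback: models without a multimodal padding hook yield None, and callers can guard on that. A small illustration of the pattern, with hypothetical model classes:

    class MultimodalModel:
        def pad_input_ids(self, input_ids, image_inputs):
            # a real model would splice image placeholder tokens here
            return input_ids

    class TextOnlyModel:
        pass

    for model in (MultimodalModel(), TextOnlyModel()):
        pad_fn = getattr(model, "pad_input_ids", None)
        if pad_fn is not None:
            print(pad_fn([1, 2, 3], image_inputs=None))
        else:
            print("no pad_input_ids hook; skipping multimodal padding")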
@@ -173,9 +185,8 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )
 
-        # Init cache
-        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        # Init memory pool and cache
+        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
         if (
             server_args.chunked_prefill_size is not None
@@ -253,20 +264,6 @@ class Scheduler:
             with_stack=True,
         )
 
-        # Init states for overlap schedule
-        if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-        else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None
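
Worth noting: both branches of the deleted block bound self.cache_finished_req to the same self.tree_cache.cache_finished_req, so the alias carried no information; the hunks below therefore replace every self.cache_finished_req(req) call with a direct self.tree_cache.cache_finished_req(req).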
@@ -779,7 +776,7 @@ class Scheduler:
             req.check_finished()
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
             elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                 self.tree_cache.cache_unfinished_req(req)
@@ -808,7 +805,7 @@ class Scheduler:
             req.check_finished()
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
             else:
                 self.tree_cache.cache_unfinished_req(req)
@@ -845,7 +842,7 @@ class Scheduler:
             )
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1069,7 +1066,7 @@ class Scheduler:
         for req in self.running_batch.reqs:
             if req.rid == recv_req.rid and not req.finished():
                 req.finished_reason = FINISH_ABORT()
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
                 break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1112,7 +1109,7 @@ def run_scheduler_process(
     suppress_other_loggers()
 
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
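
With dp_rank threaded through, a data-parallel launcher can spawn one scheduler process per (dp_rank, tp_rank) pair. A simplified sketch of such a launch loop; the GPU-id formula and argument list here are assumptions, and the real SGLang launcher also wires up ports and readiness pipes:

    import multiprocessing as mp

    def run_scheduler_process(server_args, port_args, gpu_id, tp_rank, dp_rank):
        # placeholder body; the real function builds
        # Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        print(f"scheduler up: gpu={gpu_id} tp={tp_rank} dp={dp_rank}")

    def launch_all(server_args, port_args, dp_size, tp_size):
        procs = []
        for dp_rank in range(dp_size):
            for tp_rank in range(tp_size):
                gpu_id = dp_rank * tp_size + tp_rank  # assumed GPU assignment
                p = mp.Process(
                    target=run_scheduler_process,
                    args=(server_args, port_args, gpu_id, tp_rank, dp_rank),
                )
                p.start()
                procs.append(p)
        return procs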
python/sglang/srt/managers/tp_worker.py
@@ -20,6 +20,7 @@ import logging
 import threading
 import time
 from queue import Queue
+from typing import Optional
 
 import torch
@@ -40,9 +41,10 @@ class TpModelWorker:
     def __init__(
         self,
+        server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
-        server_args: ServerArgs,
+        dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Parse args
@@ -116,6 +118,19 @@ class TpModelWorker:
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         )
 
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_cpu_group(self):
+        return self.model_runner.tp_group.cpu_group
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool,
+        )
+
     def init_overlap_status(self):
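
Taken together, the new getters mean the scheduler never touches tp_worker.model_runner directly; the worker's small public surface is the whole contract. A condensed, runnable sketch of the resulting boundary, using illustrative stubs rather than the real classes:

    from types import SimpleNamespace

    class TpWorkerSketch:
        """Owns model_runner; exposes only what the scheduler needs."""

        def __init__(self, model_runner):
            self._mr = model_runner

        def get_tp_cpu_group(self):
            return self._mr.tp_group.cpu_group

        def get_pad_input_ids_func(self):
            return getattr(self._mr.model, "pad_input_ids", None)

        def get_memory_pool(self):
            return self._mr.req_to_token_pool, self._mr.token_to_kv_pool

    class SchedulerSketch:
        def __init__(self, worker):
            # Holds only the results of the getters, never model_runner itself.
            self.tp_cpu_group = worker.get_tp_cpu_group()
            self.pad_input_ids_func = worker.get_pad_input_ids_func()
            self.req_to_token_pool, self.token_to_kv_pool = worker.get_memory_pool()

    mr = SimpleNamespace(
        tp_group=SimpleNamespace(cpu_group="cpu-group"),
        model=SimpleNamespace(),            # no pad_input_ids -> getter yields None
        req_to_token_pool="req-pool",
        token_to_kv_pool="kv-pool",
    )
    print(SchedulerSketch(TpWorkerSketch(mr)).pad_input_ids_func)  # None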