"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c245b78973c934752b5d3b73f0bb62047b1c4f3d"
Unverified Commit 59cbf476 authored by Lianmin Zheng, committed by GitHub

Unify the memory pool api and tp worker API (#1724)

parent 95946271
@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
@@ -522,12 +524,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
-                )
-
-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
-            )
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
+                )
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
+            )

             # Compute the relative logprob_start_len in an extend batch
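
This hunk replaces direct writes into `req_to_token` with the new `ReqToTokenPool.write()` call, passing the index as an explicit `(row, slice)` tuple. A minimal sketch of why that tuple is equivalent to the usual slicing syntax (toy tensor, hypothetical sizes):

import torch

# Toy stand-in for ReqToTokenPool.req_to_token: one row per request slot.
req_to_token = torch.zeros(4, 8, dtype=torch.int64)

# `t[i, a:b] = v` and `t[(i, slice(a, b))] = v` are the same operation, so the
# caller can hand the index tuple to write() instead of indexing the tensor itself.
req_pool_idx, pre_len, seq_len = 2, 0, 5
values = torch.arange(5)

req_to_token[(req_pool_idx, slice(pre_len, seq_len))] = values
assert torch.equal(req_to_token[req_pool_idx, pre_len:seq_len], values)
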
@@ -765,9 +767,8 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+        )
         self.seq_lens.add_(1)
@@ -848,7 +849,6 @@ class ScheduleBatch:
         extend_logprob_start_lens = self.extend_logprob_start_lens
         image_inputs = [r.image_inputs for r in self.reqs]
-        lora_paths = [req.lora_path for req in self.reqs]

         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -869,13 +869,14 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
             image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,
         )
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=self.extend_seq_lens,
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
             sampling_info=self.sampling_info.copy(),
             mrope_positions_delta=self.mrope_positions_delta,
         )
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
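
The new `ModelWorkerBatch.to()` moves the CPU-built batch to the GPU with non-blocking copies; the list comprehension assumes `req_to_token_pool_records` is an actual list (i.e. produced by a pool that records writes) rather than None. A minimal sketch of the same pattern on a hypothetical two-field batch:

from dataclasses import dataclass
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

@dataclass
class TinyBatch:
    # Illustrative fields only, not the real ModelWorkerBatch schema.
    input_ids: torch.Tensor
    seq_lens: torch.Tensor

    def to(self, device: str):
        # non_blocking=True lets the host-to-device copy overlap with other CPU
        # work; the overlap is only real when the source tensors are pinned.
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)

batch = TinyBatch(input_ids=torch.tensor([1, 2, 3]), seq_lens=torch.tensor([3]))
batch.to(device)
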
@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -144,25 +145,27 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.server_args.enable_overlap_schedule:
+            TpWorkerClass = TpModelWorker
+        else:
+            TpWorkerClass = TpModelWorker
+
+        self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )

+        # Init states for overlap schedule
         if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
            )
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
         else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation

         # Get token and memory info from the model worker
         (
@@ -172,9 +175,14 @@ class Scheduler:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-        ) = self.tp_worker.get_token_and_memory_info()
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

         # Print debug info
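
The hunks above make the Scheduler pick its worker entry points once at construction time: with the overlap schedule enabled it binds the non-blocking forward call and a future-token resolver, otherwise the blocking variants, so the event loops never branch on the flag per step. A generic sketch of that bind-once dispatch pattern (names and behavior are illustrative, not the actual SGLang callables):

# Illustrative only: shows the "choose the implementation at init" pattern.
class TinyScheduler:
    def __init__(self, enable_overlap: bool):
        if enable_overlap:
            # Token ids are placeholders, resolved after the GPU batch finishes.
            self.resolve_next_token_ids = self._resolve_future_ids
        else:
            # Blocking mode: the ids passed in are already real values.
            self.resolve_next_token_ids = lambda bid, x: list(x)

    def _resolve_future_ids(self, bid, x):
        return [f"future-{bid}-{i}" for i in range(len(x))]

s = TinyScheduler(enable_overlap=False)
print(s.resolve_next_token_ids(0, [1, 2, 3]))  # [1, 2, 3]
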
@@ -266,6 +274,7 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
@@ -296,6 +305,7 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
@@ -572,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
@@ -673,6 +684,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch
@@ -712,6 +724,7 @@ class Scheduler:
             batch.prepare_for_decode()

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -933,6 +946,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -1030,6 +1044,7 @@ class Scheduler:
         )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
@@ -1070,6 +1085,7 @@ class Scheduler:
                 break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
@@ -27,7 +27,7 @@ import torch
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -111,7 +111,7 @@ class TpModelWorker:
         if server_args.enable_overlap_schedule:
             self.init_overlap_status()

-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
@@ -119,6 +119,10 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

     def get_pad_input_ids_func(self):
@@ -56,6 +56,12 @@ class ReqToTokenPool:
     def clear(self):
         self.free_slots = list(range(self.size))

+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def get_write_records(self):
+        return None
+

 class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
-        self.device = device
         self.free_slots = None
         self.is_not_in_free_group = True
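
The NOTE in this hunk explains why the KV buffers use a uint8 storage dtype when the logical dtype is float8_e5m2: index_put is not implemented for fp8. A small sketch of the store-as-uint8 round trip, assuming a PyTorch build that ships torch.float8_e5m2 and allows same-size dtype reinterpretation via Tensor.view():

import torch

kv_buffer = torch.zeros(16, dtype=torch.uint8)        # storage dtype
new_values = torch.randn(4).to(torch.float8_e5m2)      # logical dtype

# Write: reinterpret the fp8 bytes as uint8 so index assignment works.
kv_buffer[0:4] = new_values.view(torch.uint8)

# Read: view the bytes back as fp8 before using them.
restored = kv_buffer[0:4].view(torch.float8_e5m2)
assert torch.equal(restored.view(torch.uint8), new_values.view(torch.uint8))
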
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
         assert len(new_indices) == len(token_ids)
-        self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-        ] = new_indices[len(req.prefix_indices) :]
+        self.req_to_token_pool.write(
+            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+            new_indices[len(req.prefix_indices) :],
+        )
         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
@@ -131,6 +131,13 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

+        if self.server_args.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results."
+            )
+
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
@@ -78,7 +78,7 @@ class SamplingBatchInfo:
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
-            device=batch.input_ids.device,
+            device=device,
         )

         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -224,3 +224,13 @@ class SamplingBatchInfo:
             vocab_size=self.vocab_size,
             device=self.device,
         )
+
+    def to(self, device: str):
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            value = getattr(self, item)
+            setattr(self, item, value.to(device, non_blocking=True))