"tutorials/vscode:/vscode.git/clone" did not exist on "3f6f6941598f669bf05447cc50018ef63cc7ab02"
Unverified Commit 59cbf476 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Unify the memory pool api and tp worker API (#1724)

parent 95946271
......@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
......@@ -522,12 +524,12 @@ class ScheduleBatch:
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
req.prefix_indices
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
self.req_to_token_pool.write(
(req.req_pool_idx, slice(pre_len, seq_len)),
out_cache_loc[pt : pt + req.extend_input_len],
)
# Compute the relative logprob_start_len in an extend batch
......@@ -765,9 +767,8 @@ class ScheduleBatch:
# Alloc mem
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
self.out_cache_loc
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens.add_(1)
......@@ -848,7 +849,6 @@ class ScheduleBatch:
extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
lora_paths = [req.lora_path for req in self.reqs]
if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
self.sampling_info.regex_fsm_states = [
......@@ -869,13 +869,14 @@ class ScheduleBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs,
lora_paths=lora_paths,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta,
)
......@@ -911,6 +912,9 @@ class ModelWorkerBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# The memory pool operation records
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
......@@ -940,6 +944,7 @@ class ModelWorkerBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens.clone(),
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool_records,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=self.extend_seq_lens,
......@@ -950,3 +955,14 @@ class ModelWorkerBatch:
sampling_info=self.sampling_info.copy(),
mrope_positions_delta=self.mrope_positions_delta,
)
def to(self, device: str):
self.input_ids = self.input_ids.to(device, non_blocking=True)
self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
self.seq_lens = self.seq_lens.to(device, non_blocking=True)
self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
self.req_to_token_pool_records = [
(x, y.to(device, non_blocking=True))
for x, y in self.req_to_token_pool_records
]
self.sampling_info.to(device)
......@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
ImageInputs,
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
......@@ -144,25 +145,27 @@ class Scheduler:
)
# Launch a tensor parallel worker
self.tp_worker = TpModelWorker(
if self.server_args.enable_overlap_schedule:
TpWorkerClass = TpModelWorker
else:
TpWorkerClass = TpModelWorker
self.tp_worker = TpWorkerClass(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
dp_rank=dp_rank,
nccl_port=port_args.nccl_port,
)
# Init states for overlap schedule
if self.server_args.enable_overlap_schedule:
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
else:
self.forward_batch_generation = self.tp_worker.forward_batch_generation
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.forward_batch_generation = self.tp_worker.forward_batch_generation
# Get token and memory info from the model worker
(
......@@ -172,9 +175,14 @@ class Scheduler:
self.max_req_input_len,
self.random_seed,
self.device,
) = self.tp_worker.get_token_and_memory_info()
worker_global_server_args_dict,
_,
_,
_,
) = self.tp_worker.get_worker_info()
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Print debug info
......@@ -266,6 +274,7 @@ class Scheduler:
@torch.inference_mode()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
while True:
......@@ -296,6 +305,7 @@ class Scheduler:
@torch.inference_mode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
self.last_batch = None
......@@ -572,6 +582,7 @@ class Scheduler:
else set([])
)
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.lora_paths
......@@ -673,6 +684,7 @@ class Scheduler:
return new_batch
def update_running_batch(self):
"""Update the current running decoding batch."""
global test_retract
batch = self.running_batch
......@@ -712,6 +724,7 @@ class Scheduler:
batch.prepare_for_decode()
def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
......@@ -933,6 +946,7 @@ class Scheduler:
return num_input_logprobs
def stream_output(self, reqs: List[Req]):
"""Stream the output to detokenizer."""
output_rids = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
......@@ -1030,6 +1044,7 @@ class Scheduler:
)
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
......@@ -1070,6 +1085,7 @@ class Scheduler:
break
def update_weights(self, recv_req: UpdateWeightReqInput):
"""In-place update of the weights."""
success, message = self.tp_worker.update_weights(recv_req)
if success:
flash_cache_success = self.flush_cache()
......
......@@ -27,7 +27,7 @@ import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
......@@ -111,7 +111,7 @@ class TpModelWorker:
if server_args.enable_overlap_schedule:
self.init_overlap_status()
def get_token_and_memory_info(self):
def get_worker_info(self):
return (
self.max_total_num_tokens,
self.max_prefill_tokens,
......@@ -119,6 +119,10 @@ class TpModelWorker:
self.max_req_input_len,
self.random_seed,
self.device,
global_server_args_dict,
self.model_runner.req_to_token_pool.size,
self.model_runner.req_to_token_pool.max_context_len,
self.model_runner.token_to_kv_pool.size,
)
def get_pad_input_ids_func(self):
......
......@@ -56,6 +56,12 @@ class ReqToTokenPool:
def clear(self):
self.free_slots = list(range(self.size))
def write(self, indices, values):
self.req_to_token[indices] = values
def get_write_records(self):
return None
class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
......@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
self.device = device
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device
self.free_slots = None
self.is_not_in_free_group = True
......
......@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
self.req_to_token_pool.req_to_token[
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
] = new_indices[len(req.prefix_indices) :]
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
)
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
......
......@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
......
......@@ -131,6 +131,13 @@ class ModelRunner:
]:
server_args.disable_cuda_graph = True
if self.server_args.enable_overlap_schedule:
logger.warning(
"Overlap scheduler is enabled. This is an experimental feature. "
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
"and embedding APIs are not supported and will lead to wrong results."
)
# Global vars
if server_args.show_time_cost:
enable_show_time_cost()
......
......@@ -78,7 +78,7 @@ class SamplingBatchInfo:
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=top_ks.max().item() <= 1,
vocab_size=vocab_size,
device=batch.input_ids.device,
device=device,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
......@@ -224,3 +224,13 @@ class SamplingBatchInfo:
vocab_size=self.vocab_size,
device=self.device,
)
def to(self, device: str):
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
value = getattr(self, item)
setattr(self, item, value.to(device, non_blocking=True))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment