Unverified commit 63ba2f8d, authored by Lianmin Zheng and committed by GitHub

Clean up batch data structures: Introducing ModelWorkerBatch (#1544)

parent 36d5acfc
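At a glance, the refactor splits the old `ScheduleBatch -> ForwardBatch` handoff into three stages: the scheduler keeps CPU-side scheduling state, a new `ModelWorkerBatch` carries only what the model forward needs, and `ForwardBatch` holds the GPU tensors. The sketch below is illustrative only; the names all come from the hunks in this commit, but the `batch` and `model_runner` objects are assumed to already be constructed (as they are in `bench_latency.py`):

```python
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

# Hypothetical walkthrough of the new per-step flow introduced by this commit.
batch.prepare_for_extend(model_runner.model_config.vocab_size)

# 1. Strip scheduling-only state into a ModelWorkerBatch (CPU metadata + a few tensors).
model_worker_batch = batch.get_model_worker_batch()

# 2. Expand it into GPU tensors and attach attention/LoRA state from the model runner.
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)

# 3. Run the forward pass and sample the next tokens.
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
```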
@@ -62,11 +62,13 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    allocate_init_ports,
     configure_logger,
     kill_child_process,
     suppress_other_loggers,
@@ -125,6 +127,11 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

+    server_args.port, server_args.additional_ports = allocate_init_ports(
+        server_args.port,
+        server_args.additional_ports,
+        server_args.dp_size,
+    )
     model_config = ModelConfig(
         server_args.model_path,
         server_args.trust_remote_code,
@@ -136,7 +143,7 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=28888,
+        nccl_port=server_args.additional_ports[-1],
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
@@ -225,17 +232,19 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    forward_batch = batch.get_forward_batch()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    forward_batch = batch.get_forward_batch()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits
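For orientation, a correctness run in this benchmark drives the two helpers above in sequence: one extend (prefill) step followed by a loop of decode steps. A minimal sketch, assuming `reqs` (a list of `Req` objects) and `model_runner` were built by the surrounding script; the loop bound of 8 is arbitrary:

```python
# Illustrative driver for extend()/decode() above (not part of the diff).
next_token_ids, _, batch = extend(reqs, model_runner)  # prefill step
output_ids = [next_token_ids]

for _ in range(8):  # arbitrary number of decode steps for this sketch
    next_token_ids, _ = decode(next_token_ids, batch, model_runner)
    output_ids.append(next_token_ids)
```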
@@ -357,7 +366,6 @@ def latency_test(
     tp_rank,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
-    _set_envs_and_config(server_args)
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
@@ -463,6 +471,7 @@ def plot_latency_test(
 def main(server_args, bench_args):
+    _set_envs_and_config(server_args)
     if server_args.model_path:
         if bench_args.correctness_test:
@@ -513,8 +522,6 @@ if __name__ == "__main__":
         format="%(message)s",
     )
-    multiprocessing.set_start_method("spawn", force=True)
     try:
         main(server_args, bench_args)
     except Exception as e:
...
@@ -62,7 +62,11 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        if forward_batch.return_logprob:
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        else:
+            return_top_logprob = False

         if forward_batch.forward_mode.is_extend():
             extend_logprob_pruned_lens_cpu = [
                 extend_len - start_len
...
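The guard matters because `top_logprobs_nums` is now `Optional` and stays `None` when no request in the batch asked for logprobs (see the `ModelWorkerBatch` and `ForwardBatch` definitions later in this diff). A small illustrative check, not taken from the repository:

```python
# Hypothetical mini-reproduction of the guarded logic in LogitsMetadata.from_forward_batch.
from typing import List, Optional

def wants_top_logprobs(return_logprob: bool, top_logprobs_nums: Optional[List[int]]) -> bool:
    # Only inspect the list when logprobs were requested, so a None list
    # can no longer raise a TypeError.
    if return_logprob:
        return any(x > 0 for x in top_logprobs_nums)
    return False

assert wants_top_logprobs(False, None) is False
assert wants_top_logprobs(True, [0, 2]) is True
```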
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Meta data for requests and batches"""
+"""
+Store information about requests and batches.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""

 import logging
 from dataclasses import dataclass
@@ -29,7 +39,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -105,6 +115,8 @@ class FINISH_ABORT(BaseFinishReason):
 @dataclass
 class ImageInputs:
+    """The image related inputs."""
+
     pixel_values: torch.Tensor
     image_hash: int
     image_sizes: Optional[list] = None
@@ -137,7 +149,7 @@ class ImageInputs:
 class Req:
-    """Store all inforamtion of a request."""
+    """The input and output status of a request."""

     def __init__(
         self,
@@ -393,20 +405,20 @@ class ScheduleBatch:
     sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
+    input_ids: List[int] = None
+    req_pool_indices: List[int] = None
+    seq_lens: List[int] = None
     out_cache_loc: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # For mixed chunekd prefill
-    prefix_lens_cpu: List[int] = None
-    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
+    top_logprobs_nums: Optional[List[int]] = None
+
+    # For extend and mixed chunekd prefill
+    prefix_lens: List[int] = None
+    extend_lens: List[int] = None
+    extend_num_tokens: int = None
+    running_bs: int = None

     # Stream
     has_stream: bool = False
@@ -466,12 +478,12 @@ class ScheduleBatch:
         seq_lens = []

         # Allocate memory
-        req_pool_indices_cpu = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         pt = 0
         for i, req in enumerate(reqs):
-            req.req_pool_idx = req_pool_indices_cpu[i]
+            req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)

             assert seq_len - pre_len == req.extend_input_len
@@ -497,22 +509,19 @@ class ScheduleBatch:
             pt += req.extend_input_len

         # Set fields
-        with torch.device("cuda"):
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+        self.input_ids = sum(input_ids, [])
+        self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
+        self.seq_lens = torch.tensor(seq_lens, device="cuda")

         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
-        if self.return_logprob:
-            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
-        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
-        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

-    def get_forward_batch(self):
-        return ForwardBatch.from_schedule_batch(self)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
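Note the CPU/GPU split after this change: `input_ids` stays a flat Python list on the CPU (flattened with `sum(input_ids, [])`), while `req_pool_indices` and `seq_lens` are uploaded as CUDA tensors right away. A tiny illustration of the flattening, with arbitrary token ids (not from the repository):

```python
# sum(list_of_lists, []) concatenates the per-request token lists into one
# flat CPU list, which is what ScheduleBatch.input_ids now holds.
per_request_input_ids = [[101, 7592], [101, 2129, 2024]]
flat = sum(per_request_input_ids, [])
assert flat == [101, 7592, 101, 2129, 2024]
```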
@@ -522,24 +531,24 @@ class ScheduleBatch:
             req.fill_ids = req.origin_input_ids + req.output_ids
             req.extend_input_len = 1

-        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        input_ids = self.input_ids + running_batch.input_ids
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
         extend_num_tokens = self.extend_num_tokens + running_bs

-        self.merge(running_batch)
+        self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        self.prefix_lens_cpu.extend(
+        self.prefix_lens.extend(
             [
                 len(r.origin_input_ids) + len(r.output_ids) - 1
                 for r in running_batch.reqs
             ]
         )
-        self.extend_lens_cpu.extend([1] * running_bs)
-        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
+        self.extend_lens.extend([1] * running_bs)
+        self.extend_logprob_start_lens.extend([0] * running_bs)

     def check_decode_mem(self):
         bs = len(self.reqs)
@@ -631,7 +640,7 @@ class ScheduleBatch:
         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, model_runner):
+    def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
@@ -688,7 +697,7 @@ class ScheduleBatch:
                     # re-applying image padding
                     if req.image_inputs is not None:
-                        req.origin_input_ids = model_runner.model.pad_input_ids(
+                        req.origin_input_ids = pad_input_ids_func(
                             req.origin_input_ids_unpadded, req.image_inputs
                         )
@@ -708,7 +717,7 @@ class ScheduleBatch:
             for r in self.reqs
         ]

-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+        self.input_ids = input_ids
         self.seq_lens.add_(1)

         # Alloc mem
@@ -731,32 +740,97 @@ class ScheduleBatch:
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
-        self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
+        self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob:
+            self.top_logprobs_nums = [
+                self.top_logprobs_nums[i] for i in unfinished_indices
+            ]
         self.has_stream = any(req.stream for req in self.reqs)

-        self.sampling_info.filter(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(unfinished_indices, new_indices)

-    def merge(self, other: "ScheduleBatch"):
+    def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.sampling_info.merge(other.sampling_info)
+        self.sampling_info.merge_batch(other.sampling_info)

         self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
         self.out_cache_loc = None
-        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob and other.return_logprob:
+            self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        elif self.return_logprob:
+            self.top_logprobs_nums.extend([0] * len(other.reqs))
+        elif other.return_logprob:
+            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.has_stream = any(req.stream for req in self.reqs)

+    def get_model_worker_batch(self):
+        if self.forward_mode.is_decode():
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
+                image_inputs
+            ) = None
+        else:
+            extend_seq_lens = self.extend_lens
+            extend_prefix_lens = self.prefix_lens
+            extend_logprob_start_lens = self.extend_logprob_start_lens
+            image_inputs = [r.image_inputs for r in self.reqs]
+
+        lora_paths = [req.lora_path for req in self.reqs]
+        self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
+
+        return ModelWorkerBatch(
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids,
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=extend_seq_lens,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_logprob_start_lens=extend_logprob_start_lens,
+            image_inputs=image_inputs,
+            lora_paths=lora_paths,
+            sampling_info=self.sampling_info,
+        )
+
+
+@dataclass
+class ModelWorkerBatch:
+    # The forward mode
+    forward_mode: ForwardMode
+    # The input ids
+    input_ids: List[int]
+    # The indices of requests in the req_to_token_pool
+    req_pool_indices: torch.Tensor
+    # The sequence length
+    seq_lens: torch.Tensor
+    # The indices of output tokens in the token_to_kv_pool
+    out_cache_loc: torch.Tensor
+
+    # For logprob
+    return_logprob: bool
+    top_logprobs_nums: Optional[List[int]]
+
+    # For extend
+    extend_seq_lens: Optional[List[int]]
+    extend_prefix_lens: Optional[List[int]]
+    extend_logprob_start_lens: Optional[List[int]]
+
+    # For multimodal
+    image_inputs: Optional[List[ImageInputs]]
+
+    # For LoRA
+    lora_paths: Optional[List[str]]
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo
@@ -141,6 +141,9 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )

         # Get token and memory info from the tp worker
         (
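Caching `pad_input_ids` via `getattr(..., None)` keeps the Scheduler decoupled from the model class: text-only models simply do not define the method and the attribute stays `None`. An illustrative guard (a hypothetical helper, not in the diff) showing how such a cached function is expected to be used:

```python
# Hypothetical usage pattern for the cached pad function.
def maybe_pad(pad_input_ids_func, origin_input_ids_unpadded, image_inputs):
    # Multimodal models expose pad_input_ids; other models leave the ids untouched.
    if pad_input_ids_func is not None and image_inputs is not None:
        return pad_input_ids_func(origin_input_ids_unpadded, image_inputs)
    return origin_input_ids_unpadded
```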
@@ -292,7 +295,7 @@ class Scheduler:
             if self.running_batch is None:
                 self.running_batch = new_batch
             else:
-                self.running_batch.merge(new_batch)
+                self.running_batch.merge_batch(new_batch)
         else:
             # Run a decode batch
             if self.running_batch is not None:
@@ -370,7 +373,7 @@ class Scheduler:
             req.image_inputs = ImageInputs.from_dict(
                 recv_req.image_inputs, self.model_config.vocab_size
             )
-            req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
+            req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids_unpadded, req.image_inputs
             )
@@ -575,9 +578,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                forward_batch = batch.get_forward_batch()
+                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    forward_batch, batch
+                    model_worker_batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
@@ -641,8 +644,8 @@ class Scheduler:
             )
         else:
             assert batch.extend_num_tokens != 0
-            forward_batch = batch.get_forward_batch()
-            embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
+            model_worker_batch = batch.get_model_worker_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -759,9 +762,7 @@ class Scheduler:
         # Check for jump-forward
         if not self.disable_regex_jump_forward:
-            jump_forward_reqs = batch.check_for_jump_forward(
-                self.tp_worker.model_runner
-            )
+            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -771,9 +772,9 @@ class Scheduler:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        forward_batch = batch.get_forward_batch()
+        model_worker_batch = batch.get_model_worker_batch()
         logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            forward_batch, batch
+            model_worker_batch
         )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
...
@@ -21,6 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -108,12 +109,14 @@ class TpModelWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
+    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        next_token_ids = self.model_runner.sample(logits_output, batch)
+        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, forward_batch: ForwardBatch):
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings.tolist()
         return embeddings
...
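With this signature change, `TpModelWorker` (in `tp_worker.py`) rather than the Scheduler owns the CPU-to-GPU transition: `ForwardBatch` never crosses the scheduler/worker boundary. A hedged sketch of the call site from the scheduler's perspective; only the two worker methods above come from the diff, and the `batch`, `tp_worker`, and `is_generation_run` names are assumed for the sketch:

```python
# Illustrative only: the scheduler hands the worker a ModelWorkerBatch
# and gets sampled token ids (or embeddings) back.
model_worker_batch = batch.get_model_worker_batch()

if is_generation_run:  # hypothetical flag for this sketch
    logits_output, next_token_ids = tp_worker.forward_batch_generation(model_worker_batch)
else:
    embeddings = tp_worker.forward_batch_embedding(model_worker_batch)
```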
@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Meta data for a forward pass."""
+"""
+Store information about a forward batch.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import torch

 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -69,25 +84,28 @@ class ForwardBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None
+
     # Position information
     positions: torch.Tensor = None

     # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None

     # For multimodal
-    image_inputs: List[ImageInputs] = None
+    image_inputs: Optional[List[ImageInputs]] = None

     # For LoRA
-    lora_paths: List[str] = None
+    lora_paths: Optional[List[str]] = None
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo = None

     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
@@ -95,42 +113,61 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None

     @classmethod
-    def from_schedule_batch(
+    def init_new(
         cls,
-        batch: ScheduleBatch,
+        batch: ModelWorkerBatch,
+        model_runner: ModelRunner,
     ):
+        device = "cuda"
+
         ret = cls(
             forward_mode=batch.forward_mode,
-            batch_size=batch.batch_size(),
-            input_ids=batch.input_ids,
+            batch_size=len(batch.seq_lens),
+            input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
-            lora_paths=[req.lora_path for req in batch.reqs],
+            lora_paths=batch.lora_paths,
+            sampling_info=batch.sampling_info,
         )

+        # Init position information
         if ret.forward_mode.is_decode():
             ret.positions = (ret.seq_lens - 1).to(torch.int64)
         else:
             ret.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                        for i, req in enumerate(batch.reqs)
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
                     ],
                     axis=0,
                 ),
-                device="cuda",
+                device=device,
             ).to(torch.int64)
-            ret.image_inputs = [r.image_inputs for r in batch.reqs]
-            ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-            ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(
+                batch.extend_prefix_lens, device=device
+            )
             ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
             ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
-            ret.extend_seq_lens_cpu = batch.extend_lens_cpu
-            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
+
+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)
+
+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)

         return ret
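To make the extend-path arithmetic in `ForwardBatch.init_new` concrete: `positions` are the absolute token indices each new token occupies, and `extend_start_loc` is the running offset of each request inside the flattened extend buffer. A small worked example with invented lengths:

```python
import numpy as np
import torch

# Two requests being extended:
#   request 0: 2 tokens already cached (prefix), 3 new tokens to extend
#   request 1: 0 tokens cached, 4 new tokens to extend
extend_prefix_lens = [2, 0]
extend_seq_lens = [3, 4]

positions = np.concatenate(
    [
        np.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ]
)
# positions == [2, 3, 4, 0, 1, 2, 3]

extend_seq_lens_t = torch.tensor(extend_seq_lens)
extend_start_loc = torch.zeros_like(extend_seq_lens_t)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens_t[:-1], dim=0)
# extend_start_loc == [0, 3]: request 1's tokens start at offset 3 in the flat buffer
```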
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type

 import torch
 import torch.nn as nn
@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constrained import disable_cache
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    enable_show_time_cost,
     get_available_gpu_memory,
     is_generation_model,
     is_multimodal_model,
@@ -102,6 +104,12 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95

+        # Global vars
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+        if server_args.disable_disk_cache:
+            disable_cache()
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -491,16 +499,6 @@ class ModelRunner:
         )

     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
-        # Attach attention information
-        forward_batch.req_to_token_pool = self.req_to_token_pool
-        forward_batch.token_to_kv_pool = self.token_to_kv_pool
-        forward_batch.attn_backend = self.attn_backend
-        forward_batch.attn_backend.init_forward_metadata(forward_batch)
-
-        # Attach lora information
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(forward_batch)
-
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
@@ -508,16 +506,27 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

-    def _apply_logits_bias(
-        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
-    ):
+    def sample(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        sampling_info = forward_batch.sampling_info
+        sampling_info.update_regex_vocab_mask()
+        sampling_info.update_penalties()
+        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, sampling_info)
+        return next_token_ids
+
+    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Apply logit_bias
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         # min-token, presence, frequency
         if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
+            logits.add_(sampling_info.linear_penalties)

         # repetition
         if sampling_info.scaling_penalties is not None:
@@ -533,20 +542,6 @@ class ModelRunner:

         return logits

-    def sample(
-        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
-    ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-        batch.sampling_info.update_regex_vocab_mask(batch)
-        batch.sampling_info.update_penalties()
-        logits = self._apply_logits_bias(
-            logits_output.next_token_logits, batch.sampling_info
-        )
-
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, batch.sampling_info)
-        return next_token_ids
-

 @lru_cache()
 def import_model_classes():
...
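The sampling path in `ModelRunner` is now self-contained: `sample` reads everything from `forward_batch.sampling_info` rather than reaching back into the ScheduleBatch. A condensed illustration of the bias and penalty steps that `apply_logits_bias` performs, using tiny placeholder tensors (the real ones are batch-by-vocab sized):

```python
import torch

# Illustrative reproduction of two steps from apply_logits_bias (values invented).
logits = torch.tensor([[1.0, 2.0, 3.0]])
logit_bias = torch.tensor([[0.0, -100.0, 0.0]])    # e.g. effectively forbid token 1
linear_penalties = torch.tensor([[0.0, 0.0, -0.5]])

logits.add_(logit_bias)          # logit_bias step
logits.add_(linear_penalties)    # min-token / presence / frequency penalties
# logits is now [[1.0, -98.0, 2.5]]; repetition scaling and the regex vocab mask
# would be applied next before the Sampler draws next_token_ids.
```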
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
 import torch

 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.constrained import RegexGuide

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -22,13 +23,17 @@ class SamplingBatchInfo:
     top_ks: torch.Tensor = None
     min_ps: torch.Tensor = None

-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Bias Tensors
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None

+    # FSM states
+    regex_fsms: List[RegexGuide] = None
+    regex_fsm_states: List[int] = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -54,6 +59,8 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         )

+        ret.regex_fsms = [r.regex_fsm for r in reqs]
+        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
@@ -102,24 +109,22 @@ class SamplingBatchInfo:
         )
         self.linear_penalties = penalizer.apply(self.linear_penalties)

-    def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
-
+    def update_regex_vocab_mask(self):
         # Reset the vocab mask
         self.vocab_mask = None

-        if has_regex:
+        if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
             self.vocab_mask = torch.zeros(
-                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0

-    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [
@@ -129,9 +134,11 @@ class SamplingBatchInfo:
             "min_ps",
             "logit_bias",
         ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+            value = getattr(self, item, None)
+            if value is not None:  # logit_bias can be None
+                setattr(self, item, value[new_indices])
+
+        self.regex_fsms = [self.regex_fsms[i] for i in new_indices]

     @staticmethod
     def merge_bias_tensor(
@@ -153,7 +160,7 @@ class SamplingBatchInfo:
             return None

-    def merge(self, other: "SamplingBatchInfo"):
+    def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [
@@ -169,3 +176,5 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other)
         )
+
+        self.regex_fsms.extend(other.regex_fsms)
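Moving the FSM handles into `SamplingBatchInfo` (`regex_fsms`, plus the per-request `regex_fsm_states` that `get_model_worker_batch` refreshes every step) lets the vocab mask be rebuilt without touching `Req` objects. A toy illustration of the masking rule below uses a stand-in FSM object, not the real RegexGuide API beyond the `get_next_instruction(state).tokens` call seen in the diff:

```python
import torch

class FakeGuide:
    """Stand-in guide for this sketch: only token ids 2 and 5 are legal."""
    def get_next_instruction(self, state):
        class _Instr:
            tokens = [2, 5]
        return _Instr()

vocab_size = 8
regex_fsms = [FakeGuide(), None]   # request 0 is constrained, request 1 is not
regex_fsm_states = [0, 0]

vocab_mask = torch.zeros(len(regex_fsms), vocab_size, dtype=torch.bool)
for i, regex_fsm in enumerate(regex_fsms):
    if regex_fsm is not None:
        vocab_mask[i].fill_(1)     # start by banning every token...
        vocab_mask[i][regex_fsm.get_next_instruction(regex_fsm_states[i]).tokens] = 0
# vocab_mask[0] bans all tokens except ids 2 and 5; vocab_mask[1] bans nothing.
```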
@@ -41,7 +41,6 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse

 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
@@ -72,8 +71,6 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     configure_logger,
-    enable_show_time_cost,
-    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
@@ -400,14 +397,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()

-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
     # Fix triton bugs
     if server_args.tp_size * server_args.dp_size > 1:
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
...