Unverified Commit 63ba2f8d authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Clean up batch data structures: Introducing ModelWorkerBatch (#1544)

parent 36d5acfc
......@@ -62,11 +62,13 @@ import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
allocate_init_ports,
configure_logger,
kill_child_process,
suppress_other_loggers,
......@@ -125,6 +127,11 @@ def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,
server_args.additional_ports,
server_args.dp_size,
)
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
......@@ -136,7 +143,7 @@ def load_model(server_args, tp_rank):
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=28888,
nccl_port=server_args.additional_ports[-1],
server_args=server_args,
)
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
......@@ -225,17 +232,19 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
forward_batch = batch.get_forward_batch()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
forward_batch = batch.get_forward_batch()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
return next_token_ids, logits_output.next_token_logits
......@@ -357,7 +366,6 @@ def latency_test(
tp_rank,
):
configure_logger(server_args, prefix=f" TP{tp_rank}")
_set_envs_and_config(server_args)
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model
......@@ -463,6 +471,7 @@ def plot_latency_test(
def main(server_args, bench_args):
_set_envs_and_config(server_args)
if server_args.model_path:
if bench_args.correctness_test:
......@@ -513,8 +522,6 @@ if __name__ == "__main__":
format="%(message)s",
)
multiprocessing.set_start_method("spawn", force=True)
try:
main(server_args, bench_args)
except Exception as e:
......
......@@ -62,7 +62,11 @@ class LogitsMetadata:
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if forward_batch.return_logprob:
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
else:
return_top_logprob = False
if forward_batch.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
......
from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Meta data for requests and batches"""
"""
Store information about requests and batches.
The following is the flow of data structures for a batch:
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
import logging
from dataclasses import dataclass
......@@ -29,7 +39,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
......@@ -105,6 +115,8 @@ class FINISH_ABORT(BaseFinishReason):
@dataclass
class ImageInputs:
"""The image related inputs."""
pixel_values: torch.Tensor
image_hash: int
image_sizes: Optional[list] = None
......@@ -137,7 +149,7 @@ class ImageInputs:
class Req:
"""Store all inforamtion of a request."""
"""The input and output status of a request."""
def __init__(
self,
......@@ -393,20 +405,20 @@ class ScheduleBatch:
sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
input_ids: List[int] = None
req_pool_indices: List[int] = None
seq_lens: List[int] = None
out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None
# For mixed chunekd prefill
prefix_lens_cpu: List[int] = None
running_bs: int = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: List[int] = None
top_logprobs_nums: Optional[List[int]] = None
# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: int = None
running_bs: int = None
# Stream
has_stream: bool = False
......@@ -466,12 +478,12 @@ class ScheduleBatch:
seq_lens = []
# Allocate memory
req_pool_indices_cpu = self.alloc_req_slots(bs)
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
pt = 0
for i, req in enumerate(reqs):
req.req_pool_idx = req_pool_indices_cpu[i]
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
seq_lens.append(seq_len)
assert seq_len - pre_len == req.extend_input_len
......@@ -497,22 +509,19 @@ class ScheduleBatch:
pt += req.extend_input_len
# Set fields
with torch.device("cuda"):
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
self.input_ids = sum(input_ids, [])
self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
self.seq_lens = torch.tensor(seq_lens, device="cuda")
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
def get_forward_batch(self):
return ForwardBatch.from_schedule_batch(self)
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
......@@ -522,24 +531,24 @@ class ScheduleBatch:
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = 1
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
input_ids = self.input_ids + running_batch.input_ids
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_bs
self.merge(running_batch)
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens_cpu.extend(
self.prefix_lens.extend(
[
len(r.origin_input_ids) + len(r.output_ids) - 1
for r in running_batch.reqs
]
)
self.extend_lens_cpu.extend([1] * running_bs)
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
self.extend_lens.extend([1] * running_bs)
self.extend_logprob_start_lens.extend([0] * running_bs)
def check_decode_mem(self):
bs = len(self.reqs)
......@@ -631,7 +640,7 @@ class ScheduleBatch:
return retracted_reqs, new_estimate_ratio
def check_for_jump_forward(self, model_runner):
def check_for_jump_forward(self, pad_input_ids_func):
jump_forward_reqs = []
filter_indices = [i for i in range(len(self.reqs))]
......@@ -688,7 +697,7 @@ class ScheduleBatch:
# re-applying image padding
if req.image_inputs is not None:
req.origin_input_ids = model_runner.model.pad_input_ids(
req.origin_input_ids = pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)
......@@ -708,7 +717,7 @@ class ScheduleBatch:
for r in self.reqs
]
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
self.input_ids = input_ids
self.seq_lens.add_(1)
# Alloc mem
......@@ -731,32 +740,97 @@ class ScheduleBatch:
self.reqs = [self.reqs[i] for i in unfinished_indices]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
self.seq_lens = self.seq_lens[new_indices]
self.input_ids = None
self.req_pool_indices = self.req_pool_indices[new_indices]
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.out_cache_loc = None
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
self.top_logprobs_nums = [
self.top_logprobs_nums[i] for i in unfinished_indices
]
self.has_stream = any(req.stream for req in self.reqs)
self.sampling_info.filter(unfinished_indices, new_indices)
self.sampling_info.filter_batch(unfinished_indices, new_indices)
def merge(self, other: "ScheduleBatch"):
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge(other.sampling_info)
self.sampling_info.merge_batch(other.sampling_info)
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.position_ids_offsets = torch.concat(
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = None
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
elif self.return_logprob:
self.top_logprobs_nums.extend([0] * len(other.reqs))
elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.has_stream = any(req.stream for req in self.reqs)
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
image_inputs
) = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
lora_paths = [req.lora_path for req in self.reqs]
self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
return ModelWorkerBatch(
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs,
lora_paths=lora_paths,
sampling_info=self.sampling_info,
)
@dataclass
class ModelWorkerBatch:
# The forward mode
forward_mode: ForwardMode
# The input ids
input_ids: List[int]
# The indices of requests in the req_to_token_pool
req_pool_indices: torch.Tensor
# The sequence length
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
# For extend
extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]]
# For multimodal
image_inputs: Optional[List[ImageInputs]]
# For LoRA
lora_paths: Optional[List[str]]
# Sampling info
sampling_info: SamplingBatchInfo
......@@ -141,6 +141,9 @@ class Scheduler:
nccl_port=port_args.nccl_ports[0],
)
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
self.pad_input_ids_func = getattr(
self.tp_worker.model_runner.model, "pad_input_ids", None
)
# Get token and memory info from the tp worker
(
......@@ -292,7 +295,7 @@ class Scheduler:
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
self.running_batch.merge_batch(new_batch)
else:
# Run a decode batch
if self.running_batch is not None:
......@@ -370,7 +373,7 @@ class Scheduler:
req.image_inputs = ImageInputs.from_dict(
recv_req.image_inputs, self.model_config.vocab_size
)
req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)
......@@ -575,9 +578,9 @@ class Scheduler:
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
forward_batch = batch.get_forward_batch()
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
forward_batch, batch
model_worker_batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
......@@ -641,8 +644,8 @@ class Scheduler:
)
else:
assert batch.extend_num_tokens != 0
forward_batch = batch.get_forward_batch()
embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
# Check finish conditions
for i, req in enumerate(batch.reqs):
......@@ -759,9 +762,7 @@ class Scheduler:
# Check for jump-forward
if not self.disable_regex_jump_forward:
jump_forward_reqs = batch.check_for_jump_forward(
self.tp_worker.model_runner
)
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
......@@ -771,9 +772,9 @@ class Scheduler:
batch.prepare_for_decode()
# Forward and sample the next tokens
forward_batch = batch.get_forward_batch()
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
forward_batch, batch
model_worker_batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
......
......@@ -21,6 +21,7 @@ import logging
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
......@@ -108,12 +109,14 @@ class TpModelWorker:
self.random_seed,
)
def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
next_token_ids = self.model_runner.sample(logits_output, batch)
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
return logits_output, next_token_ids
def forward_batch_embedding(self, forward_batch: ForwardBatch):
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings.tolist()
return embeddings
......
......@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Meta data for a forward pass."""
"""
Store information about a forward batch.
The following is the flow of data structures for a batch:
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
class ForwardMode(IntEnum):
......@@ -69,25 +84,28 @@ class ForwardBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# For logprob
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
# Position information
positions: torch.Tensor = None
# For extend
extend_seq_lens: torch.Tensor = None
extend_prefix_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None
# For logprob
return_logprob: bool = False
top_logprobs_nums: List[int] = None
extend_seq_lens_cpu: List[int] = None
extend_logprob_start_lens_cpu: List[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
extend_prefix_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
# For multimodal
image_inputs: List[ImageInputs] = None
image_inputs: Optional[List[ImageInputs]] = None
# For LoRA
lora_paths: List[str] = None
lora_paths: Optional[List[str]] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
# Attention backend
req_to_token_pool: ReqToTokenPool = None
......@@ -95,42 +113,61 @@ class ForwardBatch:
attn_backend: AttentionBackend = None
@classmethod
def from_schedule_batch(
def init_new(
cls,
batch: ScheduleBatch,
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
device = "cuda"
ret = cls(
forward_mode=batch.forward_mode,
batch_size=batch.batch_size(),
input_ids=batch.input_ids,
batch_size=len(batch.seq_lens),
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
lora_paths=[req.lora_path for req in batch.reqs],
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
)
# Init position information
if ret.forward_mode.is_decode():
ret.positions = (ret.seq_lens - 1).to(torch.int64)
else:
ret.positions = torch.tensor(
np.concatenate(
[
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
for i, req in enumerate(batch.reqs)
np.arange(prefix_len, prefix_len + extend_len)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
),
device="cuda",
device=device,
).to(torch.int64)
ret.image_inputs = [r.image_inputs for r in batch.reqs]
ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, device=device
)
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
ret.extend_seq_lens_cpu = batch.extend_lens_cpu
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
# Init attention information
ret.req_to_token_pool = model_runner.req_to_token_pool
ret.token_to_kv_pool = model_runner.token_to_kv_pool
ret.attn_backend = model_runner.attn_backend
model_runner.attn_backend.init_forward_metadata(ret)
# Init lora information
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)
return ret
......@@ -21,7 +21,7 @@ import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Tuple, Type
from typing import Optional, Type
import torch
import torch.nn as nn
......@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
......@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
is_generation_model,
is_multimodal_model,
......@@ -102,6 +104,12 @@ class ModelRunner:
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
# Global vars
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_disk_cache:
disable_cache()
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
......@@ -491,16 +499,6 @@ class ModelRunner:
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
# Attach attention information
forward_batch.req_to_token_pool = self.req_to_token_pool
forward_batch.token_to_kv_pool = self.token_to_kv_pool
forward_batch.attn_backend = self.attn_backend
forward_batch.attn_backend.init_forward_metadata(forward_batch)
# Attach lora information
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(forward_batch)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
......@@ -508,16 +506,27 @@ class ModelRunner:
else:
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
def _apply_logits_bias(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
):
def sample(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
) -> torch.Tensor:
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info = forward_batch.sampling_info
sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
# Sample the next tokens.
next_token_ids = self.sampler(logits, sampling_info)
return next_token_ids
def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# Apply logit_bias
if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
logits.add_(sampling_info.linear_penalties)
# repetition
if sampling_info.scaling_penalties is not None:
......@@ -533,20 +542,6 @@ class ModelRunner:
return logits
def sample(
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
) -> torch.Tensor:
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
batch.sampling_info.update_regex_vocab_mask(batch)
batch.sampling_info.update_penalties()
logits = self._apply_logits_bias(
logits_output.next_token_logits, batch.sampling_info
)
# Sample the next tokens.
next_token_ids = self.sampler(logits, batch.sampling_info)
return next_token_ids
@lru_cache()
def import_model_classes():
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.constrained import RegexGuide
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -22,13 +23,17 @@ class SamplingBatchInfo:
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Bias Tensors
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
# FSM states
regex_fsms: List[RegexGuide] = None
regex_fsm_states: List[int] = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
......@@ -54,6 +59,8 @@ class SamplingBatchInfo:
[r.sampling_params.min_p for r in reqs], dtype=torch.float
)
ret.regex_fsms = [r.regex_fsm for r in reqs]
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
......@@ -102,24 +109,22 @@ class SamplingBatchInfo:
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch):
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
def update_regex_vocab_mask(self):
# Reset the vocab mask
self.vocab_mask = None
if has_regex:
if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
self.vocab_mask = torch.zeros(
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
)
for i, req in enumerate(batch.reqs):
if req.regex_fsm is not None:
for i, regex_fsm in enumerate(self.regex_fsms):
if regex_fsm is not None:
self.vocab_mask[i].fill_(1)
self.vocab_mask[i][
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
] = 0
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
for item in [
......@@ -129,9 +134,11 @@ class SamplingBatchInfo:
"min_ps",
"logit_bias",
]:
self_val = getattr(self, item, None)
if self_val is not None: # logit_bias can be None
setattr(self, item, self_val[new_indices])
value = getattr(self, item, None)
if value is not None: # logit_bias can be None
setattr(self, item, value[new_indices])
self.regex_fsms = [self.regex_fsms[i] for i in new_indices]
@staticmethod
def merge_bias_tensor(
......@@ -153,7 +160,7 @@ class SamplingBatchInfo:
return None
def merge(self, other: "SamplingBatchInfo"):
def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
......@@ -169,3 +176,5 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other)
)
self.regex_fsms.extend(other.regex_fsms)
......@@ -41,7 +41,6 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
......@@ -72,8 +71,6 @@ from sglang.srt.utils import (
allocate_init_ports,
assert_pkg_version,
configure_logger,
enable_show_time_cost,
is_hip,
kill_child_process,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
......@@ -400,14 +397,6 @@ def _set_envs_and_config(server_args: ServerArgs):
# Set ulimit
set_ulimit()
# Enable show time cost for debugging
if server_args.show_time_cost:
enable_show_time_cost()
# Disable disk cache
if server_args.disable_disk_cache:
disable_cache()
# Fix triton bugs
if server_args.tp_size * server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment