Unverified Commit ffd20fcd authored by Lianmin Zheng, committed by GitHub

Make constrained decoding work for overlap scheduler (#2095)

parent 55bd97f3
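
What this change builds, in one picture: constrained decoding needs a CPU-built vocab mask at every sampling step, while the overlap scheduler runs CPU result processing concurrently with the GPU forward pass. The diff therefore attaches a threading.Event (sampling_info_done) to each batch's sampling info; the scheduler builds the next batch's grammar mask and sets the event, and the sampler waits on it only when grammars are present. Below is a minimal, self-contained sketch of that handoff; the class and function names are illustrative, not code from this commit.

import threading
import torch

class SamplingInfoSketch:
    """Stand-in for SamplingBatchInfo, reduced to the two fields this commit relies on."""
    def __init__(self):
        self.vocab_mask = None                       # built on the CPU by the scheduler
        self.sampling_info_done = threading.Event()  # set once the mask is ready

def scheduler_side(info: SamplingInfoSketch, vocab_mask: torch.Tensor):
    # Runs while the GPU is busy with the previous batch
    # (in the diff: process_batch_result -> update_regex_vocab_mask() + event.set()).
    info.vocab_mask = vocab_mask
    info.sampling_info_done.set()

def sampler_side(info: SamplingInfoSketch, logits: torch.Tensor) -> torch.Tensor:
    # Runs in the GPU worker just before sampling (in the diff: ModelRunner.sample).
    info.sampling_info_done.wait()  # only needed when constrained decoding is active
    if info.vocab_mask is not None:
        logits = logits.masked_fill(info.vocab_mask, float("-inf"))
    return logits

In the diff below, the producer role is played by process_batch_result (via next_batch_sampling_info), and the consumer role by ModelRunner.sample.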
import logging
import os
from typing import Union
import torch
......
......@@ -136,6 +136,7 @@ class ImageInputs:
image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
......@@ -187,11 +188,10 @@ class Req:
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.sampling_params = sampling_params
self.lora_path = lora_path
# Memory info
# Memory pool info
self.req_pool_idx = None
# Check finish
......@@ -428,7 +428,7 @@ bid = 0
@dataclasses.dataclass
class ScheduleBatch:
"""Store all inforamtion of a batch."""
"""Store all inforamtion of a batch on the scheduler."""
# Request, memory pool, and cache
reqs: List[Req]
......@@ -438,9 +438,9 @@ class ScheduleBatch:
# For utility
model_config: ModelConfig = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
input_ids: torch.Tensor = None
......@@ -509,7 +509,7 @@ class ScheduleBatch:
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs):
def alloc_req_slots(self, num_reqs: int):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
......@@ -610,7 +610,7 @@ class ScheduleBatch:
assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self):
def prepare_for_extend(self, enable_overlap_schedule: bool = False):
self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs)
......@@ -704,7 +704,7 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
global_server_args_dict["disable_penalizer"],
enable_overlap_schedule=enable_overlap_schedule,
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
......@@ -746,6 +746,7 @@ class ScheduleBatch:
return False
def retract_decode(self):
"""Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve retraction policy for radix cache
......@@ -886,18 +887,10 @@ class ScheduleBatch:
def prepare_for_idle(self):
self.forward_mode = ForwardMode.IDLE
self.input_ids = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
......@@ -1063,7 +1056,6 @@ class ScheduleBatch:
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
sampling_info=self.sampling_info,
)
def __str__(self):
......
......@@ -15,6 +15,7 @@ limitations under the License.
"""A scheduler that manages a tensor parallel GPU worker."""
import dataclasses
import logging
import os
import threading
......@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
......@@ -220,8 +222,12 @@ class Scheduler:
# Init running status
self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
# The current forward batch
self.cur_batch: Optional[ScheduleBatch] = None
# The last forward batch
self.last_batch: Optional[ScheduleBatch] = None
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
......@@ -336,15 +342,12 @@ class Scheduler:
@torch.no_grad()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
"""A normal scheduler loop."""
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)
......@@ -353,20 +356,8 @@ class Scheduler:
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
# Decode multiple steps to reduce the overhead
if batch.forward_mode.is_decode():
for _ in range(self.server_args.num_continuous_decode_steps - 1):
if not self.running_batch:
break
self.update_running_batch()
if not self.running_batch:
break
if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# Self-check and re-init some states when the server is idle
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
......@@ -377,9 +368,6 @@ class Scheduler:
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
self.last_batch = None
self.running_batch = None
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
......@@ -390,10 +378,24 @@ class Scheduler:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
if self.last_batch is None:
# A dummy first batch to start the pipeline for the overlap scheduler.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
tmp_batch, tmp_result = result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# Self-check and re-init some states when the server is idle
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
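
The loop above launches batch N and only afterwards processes the result of batch N-1 popped from result_queue, which is what lets CPU work (detokenization, grammar mask construction) overlap the GPU forward pass. The DUMMY_FIRST batch exists because the very first batch has no predecessor whose result-processing step would set its sampling_info_done event. A stripped-down sketch of that one-batch-lag structure, with illustrative callables rather than the real Scheduler methods:

from collections import deque

def overlap_loop(get_next_batch, launch_async, process_result):
    result_queue = deque()
    last_batch = None
    while True:
        batch = get_next_batch()
        if batch is not None:
            # Launch GPU work for batch N; the returned handle resolves later.
            result_queue.append((batch, launch_async(batch)))
        if last_batch is not None:
            # Overlap: CPU-side processing of batch N-1 runs while batch N computes.
            prev_batch, prev_result = result_queue.popleft()
            process_result(prev_batch, prev_result)
        last_batch = batch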
......@@ -806,7 +808,7 @@ class Scheduler:
self.tree_cache,
self.model_config,
)
new_batch.prepare_for_extend()
new_batch.prepare_for_extend(self.enable_overlap)
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
......@@ -893,14 +895,15 @@ class Scheduler:
return ret
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_idle():
return
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
else:
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
......@@ -953,6 +956,10 @@ class Scheduler:
else:
req.is_being_chunked -= 1
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
embeddings, bid = result
embeddings = embeddings.tolist()
......@@ -1022,6 +1029,10 @@ class Scheduler:
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
self.token_to_kv_pool.free_group_end()
......
......@@ -18,7 +18,6 @@ limitations under the License.
import dataclasses
import logging
import threading
import time
from queue import Queue
from typing import Optional
......@@ -96,9 +95,7 @@ class TpModelWorkerClient:
@torch.no_grad()
def forward_thread_func_(self):
while True:
model_worker_batch, future_token_ids_ct, compute_info_done = (
self.input_queue.get()
)
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break
self.launch_done = threading.Event()
......@@ -109,7 +106,6 @@ class TpModelWorkerClient:
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
compute_info_done.wait()
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch, self.launch_done
)
......@@ -160,15 +156,16 @@ class TpModelWorkerClient:
return logits_output, next_token_ids
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# A cuda stream sync here to avoid the cuda illegal memory access error.
_ = model_worker_batch.seq_lens[0].item()
# Push a new batch to the queue
model_worker_batch.sampling_info = dataclasses.replace(
model_worker_batch.sampling_info
)
compute_info_done = torch.cuda.Event()
compute_info_done.record()
self.input_queue.put(
(model_worker_batch, self.future_token_ids_ct, compute_info_done)
model_worker_batch.sampling_info,
sampling_info_done=threading.Event(),
)
self.cur_sampling_info = model_worker_batch.sampling_info
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)
......
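
The dataclasses.replace call above hands the forward thread a shallow copy of the sampling info with a fresh sampling_info_done event attached, and cur_sampling_info keeps a handle to the same object so that processing the previous batch's result can fill in the vocab mask and set the event. A small illustration of that snapshot pattern; the SamplingInfoStub fields are placeholders, not the real SamplingBatchInfo:

import dataclasses
import threading
from typing import Optional

import torch

@dataclasses.dataclass
class SamplingInfoStub:
    temperature: float
    vocab_mask: Optional[torch.Tensor] = None
    sampling_info_done: Optional[threading.Event] = None

original = SamplingInfoStub(temperature=0.7)
# Shallow copy: shared fields keep the same references, but this copy carries
# its own fresh Event for the batch that was just queued.
snapshot = dataclasses.replace(original, sampling_info_done=threading.Event())
assert snapshot is not original and snapshot.temperature == original.temperature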
......@@ -52,15 +52,19 @@ if TYPE_CHECKING:
class ForwardMode(IntEnum):
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
PREFILL = auto()
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
EXTEND = auto()
# Decode one token.
DECODE = auto()
# Contains both EXTEND and DECODE.
# Contains both EXTEND and DECODE when doing chunked prefill.
MIXED = auto()
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated.
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
IDLE = auto()
# A dummy first batch to start the pipeline for the overlap scheduler.
# It is now used for triggering the sampling_info_done event for the first prefill batch.
DUMMY_FIRST = auto()
def is_prefill(self):
return self == ForwardMode.PREFILL
......@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
def is_idle(self):
return self == ForwardMode.IDLE
def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST
@dataclass
class ForwardBatch:
......
......@@ -142,7 +142,6 @@ class ModelRunner:
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"disable_mla": server_args.disable_mla,
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
}
......@@ -636,8 +635,16 @@ class ModelRunner:
def sample(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
) -> torch.Tensor:
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info = forward_batch.sampling_info
if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
if sampling_info.grammars:
sampling_info.sampling_info_done.wait()
sampling_info.update_penalties()
else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
......
from __future__ import annotations
import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
import sglang.srt.sampling.penaltylib as penaltylib
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -28,6 +33,7 @@ class SamplingBatchInfo:
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
sampling_info_done: Optional[threading.Event] = None
logit_bias: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
......@@ -42,10 +48,7 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(
cls,
batch: ScheduleBatch,
vocab_size: int,
disable_penalizer: bool,
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
):
reqs = batch.reqs
device = batch.device
......@@ -79,6 +82,33 @@ class SamplingBatchInfo:
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
if enable_overlap_schedule:
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
# so it is kind of tricky to make them work with the overlap scheduler.
# It requires correctly updating the penalty logits before sampling and syncing the events.
# We will support them later.
penalizers = {
penaltylib.BatchedMinNewTokensPenalizer,
}
if (
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
):
logger.warning(
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
"when using the default overlap scheduler. They will be ignored. "
"Please add `--disable-overlap` when launching the server if you need these features. "
"The speed will be slower in that case."
)
else:
penalizers = {
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
penaltylib.BatchedRepetitionPenalizer,
}
# Each penalizer will do nothing if it evaluates itself as not required by looking at
# the sampling_params of the requests (see {_is_required()} of each penalizer). So this
# should not add much computation overhead beyond simple checks.
......@@ -86,19 +116,11 @@ class SamplingBatchInfo:
# While we could choose not to even create the class instances when they are not required,
# that would add additional complexity to the {ScheduleBatch} class, especially since we need
# to handle the {filter_batch()} and {merge_batch()} cases as well.
if disable_penalizer:
ret.penalizer_orchestrator = None
else:
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device=batch.device,
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
penaltylib.BatchedRepetitionPenalizer,
},
Penalizers=penalizers,
)
# Handle logit bias but only allocate when needed
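
As the TODO above explains, frequency and presence penalties depend on counts of tokens the model has already generated, and under the overlap scheduler those outputs are still unresolved futures when the next batch's sampling info is prepared, so only the min-new-tokens penalizer is kept in that mode. A generic (OpenAI-style) formulation of that dependency, not the penaltylib implementation:

import torch

def penalized_logits(
    logits: torch.Tensor,          # [vocab_size]
    output_counts: torch.Tensor,   # [vocab_size], counts of tokens already generated
    frequency_penalty: float,
    presence_penalty: float,
) -> torch.Tensor:
    # output_counts is derived from the model's own outputs, which the overlap
    # scheduler has not yet resolved when the next batch is being prepared.
    return (
        logits
        - frequency_penalty * output_counts
        - presence_penalty * (output_counts > 0).to(logits.dtype)
    )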
......@@ -133,13 +155,13 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self):
if not self.grammars or not any(grammar for grammar in self.grammars):
if not self.grammars:
self.vocab_mask = None
self.apply_mask = None
return
# find a grammar from the list
grammar = next(grammar for grammar in self.grammars if grammar is not None)
grammar = next(grammar for grammar in self.grammars if grammar)
# maybe we can reuse the existing mask?
self.vocab_mask = grammar.allocate_vocab_mask(
......
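
The vocab_mask built by update_regex_vocab_mask is what ModelRunner.sample later applies through apply_mask. For intuition, a generic batched version of such a mask is sketched below; the per_req_allowed input and the helper name are hypothetical and do not reflect the grammar backend's allocate_vocab_mask API shown above:

import torch

def build_batched_vocab_mask(per_req_allowed, vocab_size, device):
    # per_req_allowed[i] is a tensor of token ids allowed by request i's grammar,
    # or None if request i has no grammar (nothing gets blocked for it).
    bs = len(per_req_allowed)
    mask = torch.zeros(bs, vocab_size, dtype=torch.bool, device=device)
    for i, allowed in enumerate(per_req_allowed):
        if allowed is not None:
            mask[i] = True          # start by blocking everything for this row
            mask[i, allowed] = False
    return mask  # applied later as logits.masked_fill_(mask, float("-inf"))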
......@@ -123,7 +123,6 @@ class ServerArgs:
disable_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_penalizer: bool = False
enable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
......@@ -200,12 +199,7 @@ class ServerArgs:
)
if self.enable_overlap_schedule:
logger.warning(
"Overlap scheduler mode is enabled. This is an experimental feature. "
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
"and embedding APIs are not supported and will lead to wrong results. "
)
self.disable_penalizer = True
self.disable_jump_forward = True
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
......@@ -622,11 +616,6 @@ class ServerArgs:
action="store_true",
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--disable-penalizer",
action="store_true",
help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
)
parser.add_argument(
"--disable-nan-detection",
action="store_true",
......