Unverified commit ffd20fcd, authored by Lianmin Zheng and committed by GitHub

Make constrained decoding work for overlap scheduler (#2095)

parent 55bd97f3
 import logging
 import os
 from typing import Union

 import torch
......
@@ -136,6 +136,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None

     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
@@ -187,11 +188,10 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

         self.sampling_params = sampling_params
         self.lora_path = lora_path

-        # Memory info
+        # Memory pool info
         self.req_pool_idx = None

         # Check finish
@@ -428,7 +428,7 @@ bid = 0

 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all information of a batch."""
+    """Store all information of a batch on the scheduler."""

     # Request, memory pool, and cache
     reqs: List[Req]
@@ -438,9 +438,9 @@ class ScheduleBatch:
     # For utility
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None
@@ -509,7 +509,7 @@ class ScheduleBatch:
     def is_empty(self):
         return len(self.reqs) == 0

-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -610,7 +610,7 @@ class ScheduleBatch:
         assert len(self.out_cache_loc) == self.extend_num_tokens

-    def prepare_for_extend(self):
+    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND
         bs = len(self.reqs)
@@ -704,7 +704,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            global_server_args_dict["disable_penalizer"],
+            enable_overlap_schedule=enable_overlap_schedule,
         )
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -746,6 +746,7 @@ class ScheduleBatch:
         return False

     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]

         # TODO(lsyin): improve retraction policy for radix cache
@@ -886,18 +887,10 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
@@ -1063,7 +1056,6 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=self.sampling_info,
         )

     def __str__(self):
......
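The `prepare_for_idle` hunk above replaces the allocate-then-copy pattern (`torch.empty(...).to(self.device, non_blocking=True)`) with allocating the empty placeholder tensors directly on the target device. A standalone sketch of the two patterns; the `device` value here is an illustrative assumption, not taken from the repository:

```python
import torch

# Assumed device for illustration; any torch device behaves the same way here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Old pattern: allocate on the CPU first, then copy the (empty) tensor to the device.
input_ids_old = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)

# New pattern: allocate directly on the device, skipping the host tensor and the copy.
input_ids_new = torch.empty(0, dtype=torch.int32, device=device)

assert input_ids_old.device == input_ids_new.device
assert input_ids_old.numel() == input_ids_new.numel() == 0
```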
@@ -15,6 +15,7 @@ limitations under the License.

 """A scheduler that manages a tensor parallel GPU worker."""

+import dataclasses
 import logging
 import os
 import threading
@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -220,8 +222,12 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -336,15 +342,12 @@ class Scheduler:
     @torch.no_grad()
     def event_loop_normal(self):
-        """A normal blocking scheduler loop."""
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()

             if self.server_args.enable_dp_attention:
                 batch = self.prepare_dp_attn_batch(batch)
@@ -353,20 +356,8 @@ class Scheduler:
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        if self.server_args.enable_dp_attention:
-                            batch = self.prepare_dp_attn_batch(batch)
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
@@ -377,9 +368,6 @@ class Scheduler:
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -390,10 +378,24 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for the overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
@@ -806,7 +808,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
         )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
@@ -893,14 +895,15 @@ class Scheduler:
         return ret

     def process_batch_result(self, batch: ScheduleBatch, result):
-        if batch.forward_mode.is_idle():
-            return
-
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
-        else:
+        elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -953,6 +956,10 @@ class Scheduler:
                 else:
                     req.is_being_chunked -= 1

+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -1022,6 +1029,10 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs)

         self.token_to_kv_pool.free_group_end()
......
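The scheduler-side changes above form a handshake around `sampling_info_done`: once the results of batch N have been processed (so the grammar state already reflects batch N's accepted tokens), the scheduler fills the grammar vocab mask for batch N+1 and sets that batch's event, while the forward thread only waits on the event when grammars are actually in use. A minimal, self-contained sketch of this pattern with simplified stand-in classes, not the real SGLang types:

```python
import threading
import time

class SamplingInfo:
    """Stand-in for the per-batch sampling info used in the diff."""

    def __init__(self, has_grammar: bool):
        self.has_grammar = has_grammar
        self.vocab_mask = None
        self.sampling_info_done = threading.Event()

def scheduler_side(info: SamplingInfo):
    # Runs on the scheduler (CPU) thread after the previous batch's results
    # have been processed, i.e. after the grammar state was advanced.
    time.sleep(0.01)                            # pretend to process the previous batch
    info.vocab_mask = "mask-for-next-batch"     # plays the role of update_regex_vocab_mask()
    info.sampling_info_done.set()               # unblock the forward thread

def worker_side(info: SamplingInfo):
    # Runs on the forward thread right before sampling the next batch.
    if info.has_grammar:
        info.sampling_info_done.wait()          # only constrained requests pay for the wait
    return info.vocab_mask

info = SamplingInfo(has_grammar=True)
t = threading.Thread(target=scheduler_side, args=(info,))
t.start()
print(worker_side(info))                        # prints "mask-for-next-batch"
t.join()
```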
@@ -18,7 +18,6 @@ limitations under the License.

 import dataclasses
 import logging
 import threading
-import time
 from queue import Queue
 from typing import Optional
@@ -96,9 +95,7 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
         while True:
-            model_worker_batch, future_token_ids_ct, compute_info_done = (
-                self.input_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
             self.launch_done = threading.Event()
@@ -109,7 +106,6 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
-            compute_info_done.wait()
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch, self.launch_done
             )
@@ -160,15 +156,16 @@ class TpModelWorkerClient:
         return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # A cuda stream sync here to avoid the cuda illegal memory access error.
+        _ = model_worker_batch.seq_lens[0].item()
+
         # Push a new batch to the queue
         model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info
-        )
-        compute_info_done = torch.cuda.Event()
-        compute_info_done.record()
-        self.input_queue.put(
-            (model_worker_batch, self.future_token_ids_ct, compute_info_done)
+            model_worker_batch.sampling_info,
+            sampling_info_done=threading.Event(),
         )
+        self.cur_sampling_info = model_worker_batch.sampling_info
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
......
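`forward_batch_generation` now copies the sampling info with `dataclasses.replace` so every batch pushed to the worker queue carries its own fresh `threading.Event`, and it keeps a handle to that copy in `cur_sampling_info` for the scheduler to signal later. A small sketch of the `dataclasses.replace` idiom with a stand-in dataclass (only the fields needed for the illustration):

```python
import dataclasses
import threading
from typing import Optional

@dataclasses.dataclass
class SamplingInfo:
    # Stand-in for SamplingBatchInfo, reduced to the fields used in this sketch.
    vocab_size: int
    sampling_info_done: Optional[threading.Event] = None

original = SamplingInfo(vocab_size=32000)

# dataclasses.replace makes a shallow copy with the given fields overridden,
# so each queued batch gets its own Event without mutating the original object.
queued = dataclasses.replace(original, sampling_info_done=threading.Event())

cur_sampling_info = queued  # kept by the client so the scheduler can set the event later
assert queued is not original
assert queued.sampling_info_done is not None and original.sampling_info_done is None
```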
@@ -52,15 +52,19 @@ if TYPE_CHECKING:

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both EXTEND and DECODE.
+    # Contains both EXTEND and DECODE when doing chunked prefill.
     MIXED = auto()
-    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated.
+    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
     IDLE = auto()
+    # A dummy first batch to start the pipeline for the overlap scheduler.
+    # It is now used for triggering the sampling_info_done event for the first prefill batch.
+    DUMMY_FIRST = auto()

     def is_prefill(self):
         return self == ForwardMode.PREFILL
@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_dummy_first(self):
+        return self == ForwardMode.DUMMY_FIRST
+

 @dataclass
 class ForwardBatch:
......
@@ -142,7 +142,6 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
-                "disable_penalizer": server_args.disable_penalizer,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
             }
@@ -636,10 +635,18 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.update_regex_vocab_mask()
-        sampling_info.update_penalties()
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+            sampling_info.update_penalties()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

         # Sample the next tokens.
......
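`ModelRunner.sample` now branches on whether a `sampling_info_done` event is attached to the batch: with the overlap scheduler, the grammar vocab mask is filled by the scheduler thread and `sample` only waits when grammars are present; the normal path still computes the mask inline, overlapped with the forward pass. A toy sketch of that branch; the class and function names are illustrative, not from the repository:

```python
import threading
from typing import List, Optional

class SamplingInfoSketch:
    """Toy version of the sampling info consumed before sampling."""

    def __init__(self, grammars: Optional[List] = None,
                 sampling_info_done: Optional[threading.Event] = None):
        self.grammars = grammars
        self.sampling_info_done = sampling_info_done
        self.mask_ready = False

    def update_regex_vocab_mask(self):
        self.mask_ready = True  # in SGLang this builds the grammar vocab mask

def prepare_for_sampling(info: SamplingInfoSketch):
    if info.sampling_info_done:
        # Overlap mode: the mask is filled by the scheduler thread for the next batch;
        # only wait when a grammar is actually attached to this batch.
        if info.grammars:
            info.sampling_info_done.wait()
    else:
        # Normal mode: compute the mask inline.
        info.update_regex_vocab_mask()

# Normal (non-overlap) batch: no event attached, mask computed inline.
normal = SamplingInfoSketch(grammars=["json-grammar"])
prepare_for_sampling(normal)
assert normal.mask_ready

# Overlap batch without grammars: nothing to wait for, returns immediately.
overlap = SamplingInfoSketch(grammars=None, sampling_info_done=threading.Event())
prepare_for_sampling(overlap)
```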
 from __future__ import annotations

 import dataclasses
+import logging
+import threading
 from typing import TYPE_CHECKING, Callable, List, Optional

 import torch

 import sglang.srt.sampling.penaltylib as penaltylib

+logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -28,6 +33,7 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
+    sampling_info_done: Optional[threading.Event] = None
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -42,10 +48,7 @@ class SamplingBatchInfo:

     @classmethod
     def from_schedule_batch(
-        cls,
-        batch: ScheduleBatch,
-        vocab_size: int,
-        disable_penalizer: bool,
+        cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
     ):
         reqs = batch.reqs
         device = batch.device
@@ -79,6 +82,33 @@ class SamplingBatchInfo:
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

+        if enable_overlap_schedule:
+            # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
+            # so it is kind of tricky to make them work with the overlap scheduler.
+            # It requires correctly updating the penalty logits before sampling and syncing the events.
+            # We will support them later.
+            penalizers = {
+                penaltylib.BatchedMinNewTokensPenalizer,
+            }
+            if (
+                any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
+                or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
+                or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
+            ):
+                logger.warning(
+                    "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
+                    "when using the default overlap scheduler. They will be ignored. "
+                    "Please add `--disable-overlap` when launching the server if you need these features. "
+                    "The speed will be slower in that case."
+                )
+        else:
+            penalizers = {
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            }
+
         # Each penalizer will do nothing if it evaluates itself as not required by looking at
         # the sampling_params of the requests (see {_is_required()} of each penalizer). So this
         # should not add hefty computation overhead other than simple checks.
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially since we need to
         # handle the {filter_batch()} and {merge_batch()} cases as well.
-        if disable_penalizer:
-            ret.penalizer_orchestrator = None
-        else:
-            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-                vocab_size=vocab_size,
-                batch=batch,
-                device=batch.device,
-                Penalizers={
-                    penaltylib.BatchedFrequencyPenalizer,
-                    penaltylib.BatchedMinNewTokensPenalizer,
-                    penaltylib.BatchedPresencePenalizer,
-                    penaltylib.BatchedRepetitionPenalizer,
-                },
-            )
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=batch.device,
+            Penalizers=penalizers,
+        )

         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -133,13 +155,13 @@ class SamplingBatchInfo:
                 self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self):
-        if not self.grammars or not any(grammar for grammar in self.grammars):
+        if not self.grammars:
             self.vocab_mask = None
             self.apply_mask = None
             return

         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+        grammar = next(grammar for grammar in self.grammars if grammar)

         # maybe we can reuse the existing mask?
         self.vocab_mask = grammar.allocate_vocab_mask(
......
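`SamplingBatchInfo.from_schedule_batch` now selects the penalizer set based on `enable_overlap_schedule` and logs a warning when unsupported penalties are requested under the overlap scheduler. A condensed sketch of that selection logic, with plain strings standing in for the penalizer classes and a stand-in `Params` dataclass (not the real SGLang types):

```python
import logging
from dataclasses import dataclass
from typing import List, Set

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

@dataclass
class Params:
    # Stand-in for sampling_params with the same neutral defaults as in SGLang.
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.0

def select_penalizers(reqs: List[Params], enable_overlap_schedule: bool) -> Set[str]:
    """Mirror the selection added in from_schedule_batch, using strings for brevity."""
    if enable_overlap_schedule:
        # Only penalizers that do not depend on model outputs are kept.
        penalizers = {"min_new_tokens"}
        if any(
            r.frequency_penalty != 0.0
            or r.presence_penalty != 0.0
            or r.repetition_penalty != 1.0
            for r in reqs
        ):
            logger.warning(
                "frequency/presence/repetition penalties are ignored under the "
                "overlap scheduler; launch with --disable-overlap to use them."
            )
    else:
        penalizers = {"frequency", "min_new_tokens", "presence", "repetition"}
    return penalizers

print(select_penalizers([Params(frequency_penalty=0.5)], enable_overlap_schedule=True))
print(select_penalizers([Params()], enable_overlap_schedule=False))
```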
@@ -123,7 +123,6 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    disable_penalizer: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -200,12 +199,7 @@ class ServerArgs:
         )

         if self.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler mode is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results. "
-            )
-            self.disable_penalizer = True
+            self.disable_jump_forward = True

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -622,11 +616,6 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
-        parser.add_argument(
-            "--disable-penalizer",
-            action="store_true",
-            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
-        )
         parser.add_argument(
             "--disable-nan-detection",
             action="store_true",
......