Unverified Commit 9da5a60b authored by Lianmin Zheng, committed by GitHub

Add an option to disable penalizer (#1651)

parent 69aa937a
@@ -531,7 +531,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
...
@@ -671,9 +671,10 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            if batch.sampling_info.penalizer_orchestrator:
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
             if logits_output:
                 # Move logprobs to cpu
@@ -755,9 +756,10 @@ class Scheduler:
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        if batch.sampling_info.penalizer_orchestrator:
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
...
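The two guards above exist because, once the penalizer is disabled, `batch.sampling_info.penalizer_orchestrator` is `None`, and an unguarded `cumulate_output_tokens` call would raise `AttributeError`. A minimal, hedged sketch of the pattern, using a hypothetical stand-in class rather than the real penaltylib orchestrator:

from typing import List, Optional


class _StubOrchestrator:
    """Hypothetical stand-in for penaltylib.BatchedPenalizerOrchestrator."""

    def cumulate_output_tokens(self, next_token_ids: List[int]) -> None:
        print("accumulated", next_token_ids)


def process_decode_step(
    orchestrator: Optional[_StubOrchestrator], next_token_ids: List[int]
) -> None:
    # Mirrors the guard added in process_batch_result_prefill/_decode:
    # when the penalizer is disabled, the orchestrator is None and this is a no-op.
    if orchestrator:
        orchestrator.cumulate_output_tokens(next_token_ids)


process_decode_step(_StubOrchestrator(), [1, 2, 3])  # prints the token ids
process_decode_step(None, [1, 2, 3])                 # silently skipped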
@@ -119,6 +119,7 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
             }
         )
...
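For context, the hunk above is how the new server argument reaches the batch code: ModelRunner copies selected ServerArgs fields into `global_server_args_dict`, and the ScheduleBatch hunk at the top of this commit reads `"disable_penalizer"` back out of that dict when building the sampling info. A hedged sketch of that flow with simplified stand-ins (the names below are illustrative, not the real module layout):

import dataclasses

# Illustrative shared dict; the real one lives in the sglang runtime, not here.
global_server_args_dict = {}


@dataclasses.dataclass
class _StubServerArgs:
    disable_penalizer: bool = False


def _init_model_runner(server_args: _StubServerArgs) -> None:
    # Mirrors the ModelRunner hunk: publish the flag into the shared dict.
    global_server_args_dict.update(
        {"disable_penalizer": server_args.disable_penalizer}
    )


def _read_flag_at_batch_time() -> bool:
    # Mirrors the ScheduleBatch call site: read the flag back out.
    return global_server_args_dict["disable_penalizer"]


_init_model_runner(_StubServerArgs(disable_penalizer=True))
print(_read_flag_at_batch_time())  # True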
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -33,15 +33,20 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None
 
     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None
 
     # Device
     device: str = "cuda"
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
         with batch.input_ids.device:
             temperatures = torch.tensor(
@@ -76,17 +81,20 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=batch,
-            device=batch.input_ids.device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -97,6 +105,9 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -117,26 +128,26 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
+        if not has_regex:
+            self.vocab_mask = None
+            return
 
-        # Reset the vocab mask
-        self.vocab_mask = None
-
-        if has_regex:
-            self.vocab_mask = torch.zeros(
-                len(self.temperatures),
-                self.vocab_size,
-                dtype=torch.bool,
-                device=self.device,
-            )
-            for i, regex_fsm in enumerate(self.regex_fsms):
-                if regex_fsm is not None:
-                    self.vocab_mask[i].fill_(1)
-                    self.vocab_mask[i][
-                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                    ] = 0
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -175,7 +186,8 @@ class SamplingBatchInfo:
         return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
...
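Taken together, the SamplingBatchInfo changes make the orchestrator optional end to end: it is only constructed when `disable_penalizer` is false, and every method that touches it (`update_penalties`, `filter_batch`, `merge_batch`) either returns early or skips the call when it is `None`. A condensed, hedged sketch of that shape, using simplified stand-ins rather than the real penaltylib types:

from typing import Optional


class _StubOrchestrator:
    """Hypothetical stand-in for penaltylib.BatchedPenalizerOrchestrator."""

    def apply(self, logits):
        return logits  # the real penalizers would adjust the logits


class _StubSamplingInfo:
    def __init__(self, disable_penalizer: bool):
        # Mirrors from_schedule_batch: skip construction entirely when disabled.
        self.penalizer_orchestrator: Optional[_StubOrchestrator] = (
            None if disable_penalizer else _StubOrchestrator()
        )

    def update_penalties(self) -> None:
        # Mirrors the early return added above.
        if not self.penalizer_orchestrator:
            return
        # ... recompute linear/scaling penalties here ...


info = _StubSamplingInfo(disable_penalizer=True)
info.update_penalties()  # no-op: the penalizer machinery was never built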
@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
-    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
@@ -86,10 +86,15 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Optimization/debug options
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -99,6 +104,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
@@ -106,10 +112,6 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
-
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -224,6 +226,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
@@ -238,13 +245,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda", "xpu"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
@@ -252,17 +252,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
@@ -278,6 +267,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
@@ -440,7 +442,23 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Optimization/debug options
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,
@@ -455,6 +473,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
@@ -501,6 +521,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -534,27 +559,6 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
-        parser.add_argument(
-            "--efficient-weight-load",
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
-            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
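The flag itself is a plain store_true argument, so disabling the penalizer is just a matter of passing --disable-penalizer on the command line (typically to python -m sglang.launch_server; the launcher name is general sglang usage, not shown in this diff). A minimal, self-contained sketch of how the new argument parses, using only argparse:

import argparse

parser = argparse.ArgumentParser()
# Same flag definition as the hunk above adds to the ServerArgs CLI parser.
parser.add_argument(
    "--disable-penalizer",
    action="store_true",
    help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
)

print(parser.parse_args([]).disable_penalizer)                       # False (default)
print(parser.parse_args(["--disable-penalizer"]).disable_penalizer)  # True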