sglang · Commits · Commit 9da5a60b (unverified)

Add an option to disable penalizer (#1651)

Authored Oct 12, 2024 by Lianmin Zheng; committed by GitHub on Oct 12, 2024.
Parent: 69aa937a
Showing 5 changed files with 111 additions and 90 deletions.
python/sglang/srt/managers/schedule_batch.py        +3  -1
python/sglang/srt/managers/scheduler.py             +8  -6
python/sglang/srt/model_executor/model_runner.py    +1  -0
python/sglang/srt/sampling/sampling_batch_info.py   +47 -35
python/sglang/srt/server_args.py                    +52 -48
python/sglang/srt/managers/schedule_batch.py

@@ -531,7 +531,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
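For readers skimming the diff: the only change here is that the call site now reads the server-level flag from global_server_args_dict and forwards it to the SamplingBatchInfo factory. A minimal, self-contained sketch of that pattern, using hypothetical stand-ins rather than the real sglang classes:

# Hypothetical stand-in for sglang's module-level global_server_args_dict.
global_server_args_dict = {"disable_penalizer": False}

class SamplingInfoSketch:
    """Simplified stand-in for SamplingBatchInfo."""

    def __init__(self, vocab_size: int, disable_penalizer: bool):
        self.vocab_size = vocab_size
        # When the flag is set, the penalizer component is simply not created.
        self.penalizer_orchestrator = None if disable_penalizer else object()

    @classmethod
    def from_schedule_batch(cls, batch, vocab_size: int, disable_penalizer: bool):
        return cls(vocab_size, disable_penalizer)

# Mirrors the call site in the diff: the flag is read once per batch construction.
sampling_info = SamplingInfoSketch.from_schedule_batch(
    batch=None,
    vocab_size=32000,
    disable_penalizer=global_server_args_dict["disable_penalizer"],
)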
python/sglang/srt/managers/scheduler.py

@@ -671,9 +671,10 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            if batch.sampling_info.penalizer_orchestrator:
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )

             if logits_output:
                 # Move logprobs to cpu

@@ -755,9 +756,10 @@ class Scheduler:
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        if batch.sampling_info.penalizer_orchestrator:
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )

         self.num_generated_tokens += len(batch.reqs)

         # Move logprobs to cpu
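Both result-processing paths now guard the orchestrator access, because penalizer_orchestrator is None when the penalizer is disabled. A small, runnable sketch of the guard, with a simplified stand-in for the real penaltylib orchestrator:

from typing import List, Optional

class OrchestratorSketch:
    """Simplified stand-in for penaltylib.BatchedPenalizerOrchestrator."""

    def __init__(self) -> None:
        self.seen_tokens: List[int] = []

    def cumulate_output_tokens(self, next_token_ids: List[int]) -> None:
        # Track newly sampled tokens so frequency/presence penalties can use them.
        self.seen_tokens.extend(next_token_ids)

def process_batch_result(orchestrator: Optional[OrchestratorSketch], next_token_ids: List[int]) -> None:
    # The truthiness check doubles as a None check, matching the diff's style.
    if orchestrator:
        orchestrator.cumulate_output_tokens(next_token_ids)

process_batch_result(None, [1, 2, 3])                  # penalizer disabled: no-op
process_batch_result(OrchestratorSketch(), [1, 2, 3])  # penalizer enabled: tokens recorded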
python/sglang/srt/model_executor/model_runner.py

@@ -119,6 +119,7 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
             }
         )
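This one-line addition is what makes the flag visible to the scheduler-side code above: ModelRunner publishes ServerArgs fields into global_server_args_dict. A compressed sketch of the propagation path, with hypothetical stand-ins for the real classes:

import dataclasses

@dataclasses.dataclass
class ServerArgsSketch:
    """Hypothetical stand-in for the relevant part of ServerArgs."""
    disable_penalizer: bool = False

# Stand-in for the module-level dict that schedule_batch.py reads from.
global_server_args_dict = {}

def init_model_runner(server_args: ServerArgsSketch) -> None:
    # Mirrors the diff: publish the flag alongside the other global options.
    global_server_args_dict.update(
        {"disable_penalizer": server_args.disable_penalizer}
    )

init_model_runner(ServerArgsSketch(disable_penalizer=True))
assert global_server_args_dict["disable_penalizer"] is True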
python/sglang/srt/sampling/sampling_batch_info.py

from __future__ import annotations

import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

import torch

@@ -33,15 +33,20 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None

     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None

     # Device
     device: str = "cuda"

     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
         with batch.input_ids.device:
             temperatures = torch.tensor(

@@ -76,17 +81,20 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=batch,
-            device=batch.input_ids.device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )

         # Handle logit bias but only allocate when needed
         ret.logit_bias = None

@@ -97,6 +105,9 @@ class SamplingBatchInfo:
         return len(self.temperatures)

     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None

@@ -117,26 +128,26 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)

-        # Reset the vocab mask
-        self.vocab_mask = None
-
-        if has_regex:
-            self.vocab_mask = torch.zeros(
-                len(self.temperatures),
-                self.vocab_size,
-                dtype=torch.bool,
-                device=self.device,
-            )
-            for i, regex_fsm in enumerate(self.regex_fsms):
-                if regex_fsm is not None:
-                    self.vocab_mask[i].fill_(1)
-                    self.vocab_mask[i][
-                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                    ] = 0
+        if not has_regex:
+            self.vocab_mask = None
+            return
+
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [
             "temperatures",

@@ -175,7 +186,8 @@ class SamplingBatchInfo:
             return None

     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [
             "temperatures",
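The net effect of this file's changes: the penalizer fields become Optional, the orchestrator is only constructed when the flag is off, and every consumer (update_penalties, filter_batch, merge_batch) guards against None. A condensed sketch of that pattern with simplified stand-ins for penaltylib:

import dataclasses
from typing import List, Optional

class OrchestratorSketch:
    """Simplified stand-in for penaltylib.BatchedPenalizerOrchestrator."""

    def filter(self, unfinished_indices: List[int]) -> None:
        pass

    def merge(self, other: Optional["OrchestratorSketch"]) -> None:
        pass

@dataclasses.dataclass
class SamplingBatchInfoSketch:
    vocab_size: int
    # None means "penalizer disabled"; every caller must guard before use.
    penalizer_orchestrator: Optional[OrchestratorSketch] = None

    @classmethod
    def from_schedule_batch(cls, vocab_size: int, disable_penalizer: bool) -> "SamplingBatchInfoSketch":
        ret = cls(vocab_size=vocab_size)
        if not disable_penalizer:
            ret.penalizer_orchestrator = OrchestratorSketch()
        return ret

    def filter_batch(self, unfinished_indices: List[int]) -> None:
        if self.penalizer_orchestrator:
            self.penalizer_orchestrator.filter(unfinished_indices)

    def merge_batch(self, other: "SamplingBatchInfoSketch") -> None:
        if self.penalizer_orchestrator:
            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)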
python/sglang/srt/server_args.py

@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
-    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False

@@ -86,10 +86,15 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"

-    # Optimization/debug options
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
+
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False

@@ -99,6 +104,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32

@@ -106,10 +112,6 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False

-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
-
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:

@@ -224,6 +226,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,

@@ -238,13 +245,6 @@ class ServerArgs:
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda", "xpu"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,

@@ -252,17 +252,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,

@@ -278,6 +267,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,

@@ -440,7 +442,23 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )

-        # Optimization/debug options
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,

@@ -455,6 +473,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",

@@ -501,6 +521,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",

@@ -534,27 +559,6 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
         parser.add_argument(
             "--efficient-weight-load",
             action="store_true",
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
-            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
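The new option follows the standard argparse store_true pattern: it defaults to False in the dataclass and flips to True only when the flag is passed. Assuming the usual sglang launch entry point, starting a server with the penalizer disabled would look roughly like (model path is a placeholder):

    python -m sglang.launch_server --model-path <model-path> --disable-penalizer

A minimal stand-alone sketch of how the dataclass default and the argparse flag line up (not the full ServerArgs):

import argparse
import dataclasses

@dataclasses.dataclass
class ArgsSketch:
    """Hypothetical stand-in for the disable_penalizer part of ServerArgs."""
    disable_penalizer: bool = False

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-penalizer",
    action="store_true",  # absent -> False, present -> True
    help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
)

args = parser.parse_args(["--disable-penalizer"])
print(ArgsSketch(disable_penalizer=args.disable_penalizer))  # ArgsSketch(disable_penalizer=True)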