sglang · Commits · 9da5a60b

Add an option to disable penalizer (#1651)

Unverified commit 9da5a60b, authored Oct 12, 2024 by Lianmin Zheng, committed via GitHub on Oct 12, 2024.
Parent: 69aa937a

Showing 5 changed files with 111 additions and 90 deletions:
python/sglang/srt/managers/schedule_batch.py        +3   -1
python/sglang/srt/managers/scheduler.py             +8   -6
python/sglang/srt/model_executor/model_runner.py    +1   -0
python/sglang/srt/sampling/sampling_batch_info.py   +47  -35
python/sglang/srt/server_args.py                    +52  -48
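
The commit makes the penalizer opt-out: when the new flag is set, the batched penalizer orchestrator is never constructed, and every downstream call site checks the field for None instead of assuming it exists. A minimal self-contained sketch of that pattern (the names below are illustrative stand-ins, not the actual sglang classes):

# Minimal sketch of the pattern this commit introduces (illustrative names,
# not the actual sglang classes): an opt-out flag that skips building the
# penalizer and None-guards every downstream use.
from dataclasses import dataclass
from typing import Optional


class PenalizerOrchestrator:
    """Stand-in for penaltylib.BatchedPenalizerOrchestrator."""

    def cumulate_output_tokens(self, token_ids):
        print(f"accumulating {len(token_ids)} output tokens")


@dataclass
class SamplingInfo:
    penalizer_orchestrator: Optional[PenalizerOrchestrator] = None

    @classmethod
    def build(cls, disable_penalizer: bool) -> "SamplingInfo":
        # When disabled, the orchestrator is simply never constructed.
        return cls(None if disable_penalizer else PenalizerOrchestrator())


def process_decode_step(info: SamplingInfo, next_token_ids):
    # Every consumer guards on None instead of assuming the orchestrator exists.
    if info.penalizer_orchestrator:
        info.penalizer_orchestrator.cumulate_output_tokens(next_token_ids)


process_decode_step(SamplingInfo.build(disable_penalizer=True), [1, 2, 3])   # no-op
process_decode_step(SamplingInfo.build(disable_penalizer=False), [1, 2, 3])  # accumulates

The same shape appears in the hunks below: SamplingBatchInfo.from_schedule_batch() now takes the flag, and the Scheduler guards cumulate_output_tokens() behind an if.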
python/sglang/srt/managers/schedule_batch.py

@@ -531,7 +531,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
python/sglang/srt/managers/scheduler.py

@@ -671,9 +671,10 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            if batch.sampling_info.penalizer_orchestrator:
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
             if logits_output:
                 # Move logprobs to cpu

@@ -755,9 +756,10 @@ class Scheduler:
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        if batch.sampling_info.penalizer_orchestrator:
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
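
What the guarded call skips: cumulate_output_tokens() feeds the penalizers the tokens sampled at each step. Conceptually, and this is a simplified stand-in rather than sglang's penaltylib internals, it maintains per-request counts of generated token ids that the frequency and presence penalties later consult, so disabling the penalizer removes this bookkeeping from both the prefill and decode result paths:

# Conceptual stand-in (not sglang's penaltylib) for the bookkeeping that
# cumulate_output_tokens performs: per-request counts of generated token ids.
from collections import Counter

output_token_counts = [Counter(), Counter()]  # one counter per request in the batch

def cumulate_output_tokens(next_token_ids):
    # next_token_ids holds one sampled token id per request in the batch.
    for req_idx, token_id in enumerate(next_token_ids):
        output_token_counts[req_idx][token_id] += 1

cumulate_output_tokens([42, 7])
cumulate_output_tokens([42, 9])
print(output_token_counts)  # [Counter({42: 2}), Counter({7: 1, 9: 1})]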
python/sglang/srt/model_executor/model_runner.py

@@ -119,6 +119,7 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
             }
         )
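
This one-line addition is how the CLI flag reaches batch construction: ModelRunner publishes selected ServerArgs fields into a process-wide dict, and the schedule_batch.py hunk above reads global_server_args_dict["disable_penalizer"] from it when building SamplingBatchInfo. A simplified sketch of that handoff (stand-in names, not the full sglang modules):

# Sketch of the handoff between the model_runner.py and schedule_batch.py hunks.
from dataclasses import dataclass

# Module-level defaults, a stand-in for sglang's global_server_args_dict.
global_server_args_dict = {"disable_penalizer": False}

@dataclass
class ServerArgsStub:
    disable_penalizer: bool = False

def init_model_runner(server_args: ServerArgsStub):
    # Mirrors the model_runner.py hunk: publish selected flags process-wide.
    global_server_args_dict.update(
        {"disable_penalizer": server_args.disable_penalizer}
    )

def build_sampling_info() -> bool:
    # Mirrors the schedule_batch.py hunk: the batch reads the published flag.
    return global_server_args_dict["disable_penalizer"]

init_model_runner(ServerArgsStub(disable_penalizer=True))
print(build_sampling_info())  # True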
python/sglang/srt/sampling/sampling_batch_info.py

 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch

@@ -33,15 +33,20 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None
 
     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None
 
     # Device
     device: str = "cuda"
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
         with batch.input_ids.device:
             temperatures = torch.tensor(

@@ -76,17 +81,20 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=batch,
-            device=batch.input_ids.device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
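
For reference, the four penalizers the orchestrator registers map to common sampling parameters: the frequency and presence penalties subtract from the logits of already-generated tokens, the repetition penalty rescales them, and the min-new-tokens penalizer suppresses EOS until a minimum length is reached. A rough single-sequence sketch of the first three (illustrative values; this is generic penalty math, not sglang's penaltylib implementation):

# Generic reference sketch of the penalties that --disable-penalizer turns off.
import torch

def apply_penalties(logits, output_counts, frequency_penalty=0.2,
                    presence_penalty=0.3, repetition_penalty=1.1):
    # logits: (vocab_size,); output_counts: (vocab_size,) counts of generated tokens.
    seen = output_counts > 0
    # Frequency penalty: subtract proportionally to how often a token was generated.
    logits = logits - frequency_penalty * output_counts
    # Presence penalty: flat subtraction for any token that appeared at least once.
    logits = logits - presence_penalty * seen.float()
    # Repetition penalty: shrink positive logits / grow negative ones for seen tokens.
    penalized = torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty)
    return torch.where(seen, penalized, logits)

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
counts = torch.tensor([3.0, 0.0, 1.0, 0.0])
print(apply_penalties(logits, counts))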
@@ -97,6 +105,9 @@ class SamplingBatchInfo:
...
@@ -97,6 +105,9 @@ class SamplingBatchInfo:
return
len
(
self
.
temperatures
)
return
len
(
self
.
temperatures
)
def
update_penalties
(
self
):
def
update_penalties
(
self
):
if
not
self
.
penalizer_orchestrator
:
return
self
.
scaling_penalties
=
None
self
.
scaling_penalties
=
None
self
.
linear_penalties
=
None
self
.
linear_penalties
=
None
...
@@ -117,26 +128,26 @@ class SamplingBatchInfo:
...
@@ -117,26 +128,26 @@ class SamplingBatchInfo:
def
update_regex_vocab_mask
(
self
):
def
update_regex_vocab_mask
(
self
):
has_regex
=
self
.
regex_fsms
and
any
(
regex_fsm
for
regex_fsm
in
self
.
regex_fsms
)
has_regex
=
self
.
regex_fsms
and
any
(
regex_fsm
for
regex_fsm
in
self
.
regex_fsms
)
if
not
has_regex
:
# Reset the vocab mask
self
.
vocab_mask
=
None
self
.
vocab_mask
=
None
return
if
has_regex
:
self
.
vocab_mask
=
torch
.
zeros
(
self
.
vocab_mask
=
torch
.
zeros
(
len
(
self
.
temperatures
),
len
(
self
.
temperatures
),
self
.
vocab_size
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
dtype
=
torch
.
bool
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
for
i
,
regex_fsm
in
enumerate
(
self
.
regex_fsms
):
for
i
,
regex_fsm
in
enumerate
(
self
.
regex_fsms
):
if
regex_fsm
is
not
None
:
if
regex_fsm
is
not
None
:
self
.
vocab_mask
[
i
].
fill_
(
1
)
self
.
vocab_mask
[
i
].
fill_
(
1
)
self
.
vocab_mask
[
i
][
self
.
vocab_mask
[
i
][
regex_fsm
.
get_next_instruction
(
self
.
regex_fsm_states
[
i
]).
tokens
regex_fsm
.
get_next_instruction
(
self
.
regex_fsm_states
[
i
]).
tokens
]
=
0
]
=
0
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
],
new_indices
:
torch
.
Tensor
):
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
],
new_indices
:
torch
.
Tensor
):
self
.
penalizer_orchestrator
.
filter
(
unfinished_indices
,
new_indices
)
if
self
.
penalizer_orchestrator
:
self
.
penalizer_orchestrator
.
filter
(
unfinished_indices
,
new_indices
)
for
item
in
[
for
item
in
[
"temperatures"
,
"temperatures"
,
...
@@ -175,7 +186,8 @@ class SamplingBatchInfo:
...
@@ -175,7 +186,8 @@ class SamplingBatchInfo:
return
None
return
None
def
merge_batch
(
self
,
other
:
"SamplingBatchInfo"
):
def
merge_batch
(
self
,
other
:
"SamplingBatchInfo"
):
self
.
penalizer_orchestrator
.
merge
(
other
.
penalizer_orchestrator
)
if
self
.
penalizer_orchestrator
:
self
.
penalizer_orchestrator
.
merge
(
other
.
penalizer_orchestrator
)
for
item
in
[
for
item
in
[
"temperatures"
,
"temperatures"
,
...
...
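
Aside from the None guards, the update_regex_vocab_mask() change above is a pure early-return restructuring; the mask itself is unchanged: True marks tokens the regex FSM disallows at the current state. A small sketch of what one row of that mask encodes (illustrative sizes and token ids; the final masked_fill step is an assumption about how the mask is consumed elsewhere in the sampler):

# Sketch of one row of the regex vocab mask built in the hunk above.
import torch

vocab_size = 8
allowed_tokens = [1, 4, 5]          # e.g. regex_fsm.get_next_instruction(state).tokens

vocab_mask = torch.zeros(1, vocab_size, dtype=torch.bool)
vocab_mask[0].fill_(1)              # start with "everything disallowed"
vocab_mask[0][allowed_tokens] = 0   # then clear the tokens the FSM allows

logits = torch.randn(1, vocab_size)
masked_logits = logits.masked_fill(vocab_mask, float("-inf"))  # assumed consumption
print(masked_logits)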
python/sglang/srt/server_args.py

@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
+    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
-    device: str = "cuda"
+    context_length: Optional[int] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False

@@ -86,10 +86,15 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Optimization/debug options
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False

@@ -99,6 +104,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32

@@ -106,10 +112,6 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
-
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:

@@ -224,6 +226,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,

@@ -238,13 +245,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda", "xpu"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,

@@ -252,17 +252,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,

@@ -278,6 +267,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,

@@ -440,7 +442,23 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Optimization/debug options
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,

@@ -455,6 +473,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",

@@ -501,6 +521,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",

@@ -534,27 +559,6 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
-        parser.add_argument(
-            "--efficient-weight-load",
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
-            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
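
The new server option follows the same pattern as the other boolean flags: a dataclass field defaulting to False plus a store_true argparse option, picked up by from_cli_args(). A self-contained sketch of just that plumbing (a trimmed stand-in, not the full ServerArgs):

# Trimmed stand-in showing how the --disable-penalizer flag maps onto a
# dataclass field, in the style of ServerArgs.from_cli_args.
import argparse
import dataclasses
from dataclasses import dataclass

@dataclass
class MiniServerArgs:
    disable_penalizer: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        attrs = [f.name for f in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

parser = argparse.ArgumentParser()
MiniServerArgs.add_cli_args(parser)
print(MiniServerArgs.from_cli_args(parser.parse_args(["--disable-penalizer"])))

On the command line this surfaces as, for example, python -m sglang.launch_server ... --disable-penalizer (assuming the usual launch entry point).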