Commit 9da5a60b (sglang), authored Oct 12, 2024 by Lianmin Zheng, committed by GitHub on Oct 12, 2024

Add an option to disable penalizer (#1651)

Parent: 69aa937a
Showing 5 changed files with 111 additions and 90 deletions.
  python/sglang/srt/managers/schedule_batch.py        +3   -1
  python/sglang/srt/managers/scheduler.py             +8   -6
  python/sglang/srt/model_executor/model_runner.py    +1   -0
  python/sglang/srt/sampling/sampling_batch_info.py   +47  -35
  python/sglang/srt/server_args.py                    +52  -48
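Taken together, these changes add a --disable-penalizer server flag, thread it from ServerArgs through ModelRunner's global_server_args_dict into SamplingBatchInfo.from_schedule_batch, and guard every use of penalizer_orchestrator so that it may be None. Assuming the usual sglang launch entry point (the model path below is only a placeholder), the new flag is passed at server start, for example:

    python -m sglang.launch_server \
        --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
        --disable-penalizer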
python/sglang/srt/managers/schedule_batch.py  (view file @ 9da5a60b)

@@ -531,7 +531,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
python/sglang/srt/managers/scheduler.py  (view file @ 9da5a60b)

@@ -671,9 +671,10 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            if batch.sampling_info.penalizer_orchestrator:
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
             if logits_output:
                 # Move logprobs to cpu
@@ -755,9 +756,10 @@ class Scheduler:
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        if batch.sampling_info.penalizer_orchestrator:
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
 
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
python/sglang/srt/model_executor/model_runner.py  (view file @ 9da5a60b)

@@ -119,6 +119,7 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
             }
         )
python/sglang/srt/sampling/sampling_batch_info.py  (view file @ 9da5a60b)

 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -33,15 +33,20 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None
 
     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None
 
     # Device
     device: str = "cuda"
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
         with batch.input_ids.device:
             temperatures = torch.tensor(
@@ -76,17 +81,20 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=batch,
-            device=batch.input_ids.device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -97,6 +105,9 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -117,26 +128,26 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-
-        # Reset the vocab mask
-        self.vocab_mask = None
-
-        if has_regex:
-            self.vocab_mask = torch.zeros(
-                len(self.temperatures),
-                self.vocab_size,
-                dtype=torch.bool,
-                device=self.device,
-            )
-            for i, regex_fsm in enumerate(self.regex_fsms):
-                if regex_fsm is not None:
-                    self.vocab_mask[i].fill_(1)
-                    self.vocab_mask[i][
-                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                    ] = 0
+        if not has_regex:
+            self.vocab_mask = None
+            return
+
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -175,7 +186,8 @@ class SamplingBatchInfo:
         return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
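The net effect in this file is that penalizer_orchestrator becomes Optional: with disable_penalizer=True, from_schedule_batch leaves it as None, and update_penalties, filter_batch, and merge_batch (like the scheduler above) check it before use. A minimal, self-contained sketch of that pattern (toy classes for illustration only, not sglang's actual ones) behaves as follows:

    from dataclasses import dataclass
    from typing import List, Optional


    class ToyOrchestrator:
        """Hypothetical stand-in for penaltylib.BatchedPenalizerOrchestrator."""

        def __init__(self) -> None:
            self.seen_tokens: List[int] = []

        def cumulate_output_tokens(self, token_ids: List[int]) -> None:
            self.seen_tokens.extend(token_ids)


    @dataclass
    class ToySamplingInfo:
        # None means the penalizer was disabled at server start.
        penalizer_orchestrator: Optional[ToyOrchestrator] = None

        @classmethod
        def from_schedule_batch(cls, disable_penalizer: bool) -> "ToySamplingInfo":
            if disable_penalizer:
                return cls(penalizer_orchestrator=None)
            return cls(penalizer_orchestrator=ToyOrchestrator())

        def on_new_tokens(self, token_ids: List[int]) -> None:
            # Every consumer guards on the orchestrator, mirroring this commit.
            if self.penalizer_orchestrator:
                self.penalizer_orchestrator.cumulate_output_tokens(token_ids)


    info = ToySamplingInfo.from_schedule_batch(disable_penalizer=True)
    info.on_new_tokens([1, 2, 3])   # no-op: penalizer disabled
    info = ToySamplingInfo.from_schedule_batch(disable_penalizer=False)
    info.on_new_tokens([1, 2, 3])   # tokens accumulated for penalty bookkeeping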
python/sglang/srt/server_args.py  (view file @ 9da5a60b)

@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
-    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
@@ -86,10 +86,15 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Optimization/debug options
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -99,6 +104,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
@@ -106,10 +112,6 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
-
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -224,6 +226,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
@@ -238,13 +245,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda", "xpu"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
@@ -252,17 +252,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
@@ -278,6 +267,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
@@ -440,7 +442,23 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Optimization/debug options
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,
@@ -455,6 +473,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
        parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
@@ -501,6 +521,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -534,27 +559,6 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
-        parser.add_argument(
-            "--efficient-weight-load",
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
-            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
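The four penalizers bundled by the orchestrator (frequency, presence, repetition, and min-new-tokens) back the corresponding per-request sampling parameters, so a server started with --disable-penalizer presumably accepts those fields but leaves them without effect, in exchange for skipping the per-batch orchestrator work. For illustration, a request against a local server on sglang's default port (assumed to be 30000 here) might look like this; the commented fields are the ones the flag would neutralize:

    import requests  # assumes: pip install requests, and a server started with --disable-penalizer

    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "max_new_tokens": 32,
                # With --disable-penalizer no orchestrator is created, so these
                # penalty-related knobs would presumably be ignored:
                "frequency_penalty": 0.5,
                "presence_penalty": 0.5,
                "repetition_penalty": 1.1,
                "min_new_tokens": 4,
            },
        },
    )
    print(resp.json())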