Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
899cf5c4
Unverified
Commit
899cf5c4
authored
Sep 15, 2024
by
Lianmin Zheng
Committed by
GitHub
Sep 15, 2024
Browse files
Remove deprecated configs (#1431)
parent
e79f6cd7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
13 additions
and
24 deletions
+13
-24
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+7
-2
python/sglang/global_config.py
python/sglang/global_config.py
+5
-13
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+0
-3
python/sglang/srt/layers/attention_backend.py
python/sglang/srt/layers/attention_backend.py
+1
-5
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+0
-1
No files found.
python/sglang/bench_latency.py
View file @
899cf5c4
...
@@ -63,7 +63,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
...
@@ -63,7 +63,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
suppress_other_loggers
from
sglang.srt.utils
import
kill_child_process
,
suppress_other_loggers
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -502,4 +502,9 @@ if __name__ == "__main__":
...
@@ -502,4 +502,9 @@ if __name__ == "__main__":
format
=
"%(message)s"
,
format
=
"%(message)s"
,
)
)
main
(
server_args
,
bench_args
)
try
:
main
(
server_args
,
bench_args
)
except
Exception
as
e
:
raise
e
finally
:
kill_child_process
(
os
.
getpid
(),
including_parent
=
False
)
python/sglang/global_config.py
View file @
899cf5c4
"""Global configurations"""
"""Global configurations"""
import
os
class
GlobalConfig
:
class
GlobalConfig
:
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -16,30 +18,20 @@ class GlobalConfig:
...
@@ -16,30 +18,20 @@ class GlobalConfig:
self
.
base_min_new_token_ratio
=
0.1
self
.
base_min_new_token_ratio
=
0.1
self
.
new_token_ratio_decay
=
0.001
self
.
new_token_ratio_decay
=
0.001
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
# This can improve the speed for large batch sizes during prefill.
self
.
layer_sync_threshold
=
8192
# Runtime constants: others
# Runtime constants: others
self
.
num_continue_decode_steps
=
10
self
.
num_continue_decode_steps
=
10
self
.
retract_decode_steps
=
20
self
.
retract_decode_steps
=
20
self
.
flashinfer_workspace_size
=
384
*
1024
*
1024
self
.
flashinfer_workspace_size
=
os
.
environ
.
get
(
"FLASHINFER_WORKSPACE_SIZE"
,
384
*
1024
*
1024
)
# Output tokenization configs
# Output tokenization configs
self
.
skip_special_tokens_in_output
=
True
self
.
skip_special_tokens_in_output
=
True
self
.
spaces_between_special_tokens_in_out
=
True
self
.
spaces_between_special_tokens_in_out
=
True
# Interpreter optimization configs
# Interpreter optimization configs
self
.
eager_fill_image
=
False
self
.
enable_precache_with_tracing
=
True
self
.
enable_precache_with_tracing
=
True
self
.
enable_parallel_encoding
=
True
self
.
enable_parallel_encoding
=
True
self
.
enable_parallel_decoding
=
True
# Deprecated
# Choices: ["no_adjust", "adjust_cache"]
# no_adjust: Do not adjust the position embedding of KV cache.
# adjust_cache: Adjust the position embedding of KV cache.
self
.
concate_and_append_mode
=
"no_adjust"
global_config
=
GlobalConfig
()
global_config
=
GlobalConfig
()
python/sglang/lang/interpreter.py
View file @
899cf5c4
...
@@ -434,9 +434,6 @@ class StreamExecutor:
...
@@ -434,9 +434,6 @@ class StreamExecutor:
self
.
cur_images
.
append
((
path
,
base64_data
))
self
.
cur_images
.
append
((
path
,
base64_data
))
self
.
text_
+=
self
.
chat_template
.
image_token
self
.
text_
+=
self
.
chat_template
.
image_token
# if global_config.eager_fill_image:
# self.backend.fill_image(self)
def
_spec_gen
(
self
,
sampling_params
):
def
_spec_gen
(
self
,
sampling_params
):
stop
=
sampling_params
.
stop
stop
=
sampling_params
.
stop
max_new_tokens
=
sampling_params
.
max_new_tokens
max_new_tokens
=
sampling_params
.
max_new_tokens
...
...
python/sglang/srt/layers/attention_backend.py
View file @
899cf5c4
...
@@ -150,7 +150,7 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -150,7 +150,7 @@ class FlashInferAttnBackend(AttentionBackend):
# Some heuristics to check whether to use ragged forward
# Some heuristics to check whether to use ragged forward
use_ragged
=
False
use_ragged
=
False
if
(
if
(
int
(
torch
.
sum
(
input_metadata
.
seq_lens
))
>
4096
torch
.
sum
(
input_metadata
.
seq_lens
)
.
item
(
)
>
=
4096
and
self
.
model_runner
.
sliding_window_size
is
None
and
self
.
model_runner
.
sliding_window_size
is
None
):
):
use_ragged
=
True
use_ragged
=
True
...
@@ -301,10 +301,6 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -301,10 +301,6 @@ class FlashInferAttnBackend(AttentionBackend):
layer
.
layer_id
,
input_metadata
.
out_cache_loc
,
k
,
v
layer
.
layer_id
,
input_metadata
.
out_cache_loc
,
k
,
v
)
)
if
total_num_tokens
>=
global_config
.
layer_sync_threshold
:
# TODO: Revisit this. Why is this synchronize needed?
torch
.
cuda
.
synchronize
()
return
o
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
head_dim
)
return
o
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
head_dim
)
def
forward_decode
(
self
,
q
,
k
,
v
,
layer
:
nn
.
Module
,
input_metadata
:
InputMetadata
):
def
forward_decode
(
self
,
q
,
k
,
v
,
layer
:
nn
.
Module
,
input_metadata
:
InputMetadata
):
...
...
python/sglang/test/test_utils.py
View file @
899cf5c4
...
@@ -304,7 +304,6 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
...
@@ -304,7 +304,6 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
def
select_sglang_backend
(
args
:
argparse
.
Namespace
):
def
select_sglang_backend
(
args
:
argparse
.
Namespace
):
if
args
.
backend
.
startswith
(
"srt"
):
if
args
.
backend
.
startswith
(
"srt"
):
if
args
.
backend
==
"srt-no-parallel"
:
if
args
.
backend
==
"srt-no-parallel"
:
global_config
.
enable_parallel_decoding
=
False
global_config
.
enable_parallel_encoding
=
False
global_config
.
enable_parallel_encoding
=
False
backend
=
RuntimeEndpoint
(
f
"
{
args
.
host
}
:
{
args
.
port
}
"
)
backend
=
RuntimeEndpoint
(
f
"
{
args
.
host
}
:
{
args
.
port
}
"
)
elif
args
.
backend
.
startswith
(
"gpt-"
):
elif
args
.
backend
.
startswith
(
"gpt-"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment