sglang / Commits / 90227800

Unverified commit 90227800, authored Aug 25, 2024 by Lianmin Zheng and committed via GitHub on Aug 25, 2024.

[Minor] Improve the function organization in TokenizerManager & improve loggers (#1208)
Parent: 30b4f771

Showing 12 changed files with 137 additions and 134 deletions (+137, -134).
docs/en/hyperparameter_tuning.md                    +1   -1
python/sglang/srt/hf_transformers_utils.py          +0   -11
python/sglang/srt/managers/controller_multi.py      +2   -5
python/sglang/srt/managers/controller_single.py     +7   -7
python/sglang/srt/managers/detokenizer_manager.py   +12  -8
python/sglang/srt/managers/tokenizer_manager.py     +81  -75
python/sglang/srt/managers/tp_worker.py             +8   -6
python/sglang/srt/model_executor/model_runner.py    +7   -10
python/sglang/srt/openai_api/adapter.py             +2   -2
python/sglang/srt/server.py                         +3   -5
python/sglang/srt/server_args.py                    +1   -1
python/sglang/srt/utils.py                          +13  -3
docs/en/hyperparameter_tuning.md

@@ -6,7 +6,7 @@ Achieving a large batch size is the most important thing for attaining high throughput.
 When the server is running at full load, look for the following in the log:
-```[gpu=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```
+```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```

 ### Tune Your Request Submission Speed
 `#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed.
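Not part of the commit, but as a quick illustration of the tuning advice in that doc: a throwaway script along the following lines could measure how often `#queue-req` drops to zero. The log file name and the 50% threshold are assumptions for the sketch, not anything sglang produces or recommends itself.

```python
# Hypothetical scan of captured server output for the "#queue-req" field;
# "sglang_server.log" is an assumed file name.
import re

pattern = re.compile(r"#queue-req: (\d+)")

with open("sglang_server.log") as f:
    depths = [int(m.group(1)) for m in map(pattern.search, f) if m]

if depths:
    empty = sum(d == 0 for d in depths) / len(depths)
    print(f"#queue-req == 0 in {empty:.0%} of decode batches")
    if empty > 0.5:
        print("Likely bottlenecked by request submission speed.")
```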
python/sglang/srt/hf_transformers_utils.py

@@ -142,17 +142,6 @@ def get_tokenizer(
             raise ValueError(
                 "Cannot use the fast tokenizer in slow tokenizer mode."
             )
         kwargs["use_fast"] = False

-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        warnings.warn(
-            "For some LLaMA V1 models, initializing the fast tokenizer may "
-            "take a long time. To reduce the initialization time, consider "
-            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer."
-        )
-
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
python/sglang/srt/managers/controller_multi.py

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -193,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)
python/sglang/srt/managers/controller_single.py

@@ -27,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -52,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue

-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)

         if not self.is_dp_worker:

@@ -133,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)

     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
python/sglang/srt/managers/detokenizer_manager.py

@@ -56,6 +56,7 @@ class DetokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

@@ -75,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            recv_obj = await self.recv_from_router.recv_pyobj()

             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,

@@ -88,19 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue
-
-            if isinstance(recv_obj, UpdateWeightReqOutput):
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
+                # If it is a weight update request, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue

+            assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)

-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):

@@ -134,6 +137,7 @@ class DetokenizerManager:
                     spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
                 )

+            # Incremental decoding
             output_strs = []
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
python/sglang/srt/managers/tokenizer_manager.py

@@ -21,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import transformers

@@ -80,6 +80,7 @@ class TokenizerManager:
     ):
         self.server_args = server_args

+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

@@ -87,6 +88,7 @@ class TokenizerManager:
         self.send_to_router = context.socket(zmq.PUSH)
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

+        # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(

@@ -104,6 +106,7 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)

+        # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:

@@ -127,6 +130,7 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
             )

+        # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

@@ -134,63 +138,6 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None

-    async def get_pixel_values(self, image_data, aspect_ratio=None):
-        aspect_ratio = (
-            getattr(self.hf_config, "image_aspect_ratio", None)
-            if aspect_ratio is None
-            else aspect_ratio
-        )
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-            and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            pixel_values, image_hash, image_size = [], [], []
-            if len(image_data) > 1:
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                for img_data in image_data:
-                    pixel_v, image_h, image_s = await self._process_single_image(
-                        img_data, aspect_ratio, grid_pinpoints
-                    )
-                    pixel_values.append(pixel_v)
-                    image_hash.append(image_h)
-                    image_size.append(image_s)
-                pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_hash = [image_hash]
-                image_size = [image_size]
-        elif isinstance(image_data, str):
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hash = [image_hash]
-            image_size = [image_size]
-        else:
-            pixel_values, image_hash, image_size = None, None, None
-
-        return pixel_values, image_hash, image_size
-
-    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                get_pixel_values,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-            )
-        else:
-            return get_pixel_values(
-                image_data, aspect_ratio, grid_pinpoints, self.processor
-            )
-
     async def generate_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
     ):

@@ -198,7 +145,7 @@ class TokenizerManager:
             self.create_handle_loop()

         while self.model_update_lock.locked():
-            await asyncio.sleep(0)
+            await asyncio.sleep(0.001)

         obj.post_init()
         is_single = obj.is_single

@@ -214,8 +161,8 @@ class TokenizerManager:
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request,
-        index=None,
-        is_cache_for_prefill=False,
+        index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
             not_use_index = index is None

@@ -235,7 +182,7 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
                     obj.image_data
                 )
                 return_logprob = (

@@ -345,7 +292,7 @@ class TokenizerManager:
             parallel_sample_num = obj.parallel_sample_num

             if parallel_sample_num != 1:
-                # Send prefill requests to cache the common input
+                # Send prefill requests to cache the common prefix
                 parallel_sample_num += 1
                 input_id_result = [] if obj.input_ids is None else None
                 for i in range(batch_size):

@@ -436,7 +383,6 @@ class TokenizerManager:
             )

         # Then process the responses based on streaming option
-
         is_stream = hasattr(obj, "stream") and obj.stream

         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]

@@ -482,9 +428,9 @@ class TokenizerManager:
     async def _get_pixel_values(self, image_data):
         if isinstance(image_data, list) and len(image_data) > 0:
-            return await self.get_pixel_values(image_data[0])
+            return await self._get_pixel_values_internal(image_data[0])
         elif isinstance(image_data, str):
-            return await self.get_pixel_values(image_data)
+            return await self._get_pixel_values_internal(image_data)
         else:
             return None, None, None

@@ -563,6 +509,13 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

+    def abort_request(self, rid: str):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
     async def update_weights(self, obj: UpdateWeightReqInput, request):
         if self.to_create_loop:
             self.create_handle_loop()

@@ -587,13 +540,6 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

-    def abort_request(self, rid: str):
-        if rid not in self.rid_to_state:
-            return
-        del self.rid_to_state[rid]
-        req = AbortReq(rid)
-        self.send_to_router.send_pyobj(req)
-
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():

@@ -617,6 +563,8 @@ class TokenizerManager:
         loop.create_task(self.handle_loop())

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
             recv_obj: Union[
                 BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput

@@ -713,11 +661,69 @@ class TokenizerManager:
             )
         return top_logprobs

+    async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
+        aspect_ratio = (
+            getattr(self.hf_config, "image_aspect_ratio", None)
+            if aspect_ratio is None
+            else aspect_ratio
+        )
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            pixel_values, image_hash, image_size = [], [], []
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                for img_data in image_data:
+                    pixel_v, image_h, image_s = await self._process_single_image(
+                        img_data, aspect_ratio, grid_pinpoints
+                    )
+                    pixel_values.append(pixel_v)
+                    image_hash.append(image_h)
+                    image_size.append(image_s)
+                pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hash = [image_hash]
+                image_size = [image_size]
+        elif isinstance(image_data, str):
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hash = [image_hash]
+            image_size = [image_size]
+        else:
+            pixel_values, image_hash, image_size = None, None, None
+
+        return pixel_values, image_hash, image_size
+
+    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                _process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+            )
+        else:
+            return _process_single_image_task(
+                image_data, aspect_ratio, grid_pinpoints, self.processor
+            )
+

 global global_processor


 def init_global_processor(server_args: ServerArgs):
     """Init the global processor for multi modal models."""
     global global_processor
     transformers.logging.set_verbosity_error()
     global_processor = get_processor(

@@ -727,7 +733,7 @@ def init_global_processor(server_args: ServerArgs):
     )


-def get_pixel_values(
+def _process_single_image_task(
     image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
 ):
     try:

@@ -759,4 +765,4 @@ def get_pixel_values(
             pixel_values = pixel_values.astype(np.float16)
         return pixel_values, image_hash, image.size
     except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
+        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
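The relocated `_process_single_image` helper keeps the event loop responsive by pushing CPU-bound preprocessing into an executor when one is configured. Below is a self-contained sketch of that run-in-executor pattern only; `preprocess`, the pool size, and the inputs are hypothetical stand-ins, not sglang's actual processor code.

```python
# Hedged sketch of the run_in_executor pattern used by _process_single_image;
# preprocess() is a stand-in for the real image preprocessing work.
import asyncio
from concurrent.futures import ProcessPoolExecutor
from typing import Optional


def preprocess(image_ref: str) -> str:
    # Pretend this is expensive, CPU-bound image decoding/resizing.
    return image_ref.upper()


async def process_one(executor: Optional[ProcessPoolExecutor], image_ref: str) -> str:
    if executor is not None:
        loop = asyncio.get_event_loop()
        # Off-load the blocking call so other coroutines keep running.
        return await loop.run_in_executor(executor, preprocess, image_ref)
    # Fallback: run inline when no executor was created.
    return preprocess(image_ref)


async def main() -> None:
    with ProcessPoolExecutor(max_workers=2) as executor:
        results = await asyncio.gather(
            process_one(executor, "img_a.png"),
            process_one(executor, "img_b.png"),
        )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```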
python/sglang/srt/managers/tp_worker.py

@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,

@@ -145,7 +146,6 @@ class ModelTpServer:
         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "

@@ -284,7 +284,7 @@ class ModelTpServer:
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
             logger.info(
-                f"[gpu={self.gpu_id}] Decode batch. "
+                f"Decode batch. "
                 f"#running-req: {len(self.running_batch.reqs)}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "

@@ -443,7 +443,7 @@ class ModelTpServer:
             if num_mixed_running > 0:
                 logger.info(
-                    f"[gpu={self.gpu_id}] Prefill batch"
+                    f"Prefill batch"
                     f"(mixed #running-req: {num_mixed_running}). "
                     f"#new-seq: {len(can_run_list)}, "
                     f"#new-token: {adder.log_input_tokens}, "

@@ -453,7 +453,7 @@ class ModelTpServer:
                 )
             else:
                 logger.info(
-                    f"[gpu={self.gpu_id}] Prefill batch. "
+                    f"Prefill batch. "
                     f"#new-seq: {len(can_run_list)}, "
                     f"#new-token: {adder.log_input_tokens}, "
                     f"#cached-token: {adder.log_hit_tokens}, "

@@ -631,7 +631,7 @@ class ModelTpServer:
             self.new_token_ratio = new_token_ratio

             logger.info(
-                "decode out of memory happened, "
+                "Decode out of memory happened. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )

@@ -848,7 +848,9 @@ def run_tp_server(
     nccl_port: int,
     model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,
python/sglang/srt/model_executor/model_runner.py

@@ -109,7 +109,7 @@ class ModelRunner:
     def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
+        logger.info("Init nccl begin.")

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

@@ -152,8 +152,7 @@ class ModelRunner:
     def load_model(self):
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(

@@ -208,7 +207,7 @@ class ModelRunner:
         )
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight end. "
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"

@@ -224,7 +223,7 @@ class ModelRunner:
         from vllm.model_executor.model_loader.utils import set_default_torch_dtype

         logger.info(
-            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"Update weights begin. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -298,7 +297,7 @@ class ModelRunner:
         self.load_config = load_config
         self.model_config.path = model_path

-        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        logger.info("Update weights end.")
         return True, "Succeeded to update model weights"

     def profile_max_num_token(self, total_gpu_memory: int):

@@ -387,7 +386,7 @@ class ModelRunner:
             layer_num=self.model_config.num_hidden_layers,
         )
         logger.info(
-            f"[gpu={self.gpu_id}] Memory pool end. "
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -473,9 +472,7 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return

-        logger.info(
-            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
-        )
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")

         if self.server_args.disable_cuda_graph_padding:
             batch_size_list = list(range(1, 32)) + [64, 128]
python/sglang/srt/openai_api/adapter.py

@@ -123,7 +123,7 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name

-    print(f"Use chat template: {chat_template_arg}")
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(

@@ -355,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }
     except Exception as e:
-        print("error in SGLang:", e)
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
python/sglang/srt/server.py

@@ -74,6 +74,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,

@@ -270,15 +271,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager

-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     server_args.check_server_args()
     _set_envs_and_config(server_args)

-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
python/sglang/srt/server_args.py

@@ -418,7 +418,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
-            help="Enabling mixing prefill and decode in a chunked batch.",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
         )
         parser.add_argument(
             "--enable-torch-compile",
python/sglang/srt/utils.py

@@ -692,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)


-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
         if request.method == "OPTIONS":

@@ -704,7 +704,7 @@ def add_api_key_middleware(app, api_key):
         return await call_next(request)


-def prepare_model(model_path):
+def prepare_model(model_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download

@@ -713,7 +713,7 @@ def prepare_model(model_path):
     return model_path


-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(tokenizer_path):
             from modelscope import snapshot_download

@@ -722,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
             tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
         )
     return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
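For illustration only (not part of the commit): the standalone sketch below mimics the format string that the new `configure_logger` installs, and shows why the per-message `[gpu=...]` prefixes dropped in tp_worker.py and model_runner.py become redundant. The `" TP0"` prefix and the sample message come from the diff; the function name, log level handling, and example output are assumptions.

```python
# Minimal sketch of the new logging setup; mirrors configure_logger's format
# string from the diff but is otherwise a self-contained, hypothetical example.
import logging


def configure_logger_demo(log_level: str = "INFO", prefix: str = "") -> None:
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format=f"[%(asctime)s{prefix}] %(message)s",
        datefmt="%H:%M:%S",
        force=True,  # reconfigure even if logging was already set up
    )


if __name__ == "__main__":
    configure_logger_demo(prefix=" TP0")
    logging.getLogger(__name__).info(
        "Decode batch. #running-req: 233, token usage: 0.82"
    )
    # Example output: [12:34:56 TP0] Decode batch. #running-req: 233, token usage: 0.82
```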