sglang · commit bea2bb9e (unverified)

Improve multi-node stability (#1171)

Authored Aug 20, 2024 by Lianmin Zheng; committed by GitHub on Aug 20, 2024
Parent: cd10654e
Changes: 11 changed files with 94 additions and 76 deletions (+94 / -76)
python/sglang/launch_server.py                         +8   -1
python/sglang/srt/hf_transformers_utils.py             +11  -5
python/sglang/srt/managers/controller_multi.py         +0   -2
python/sglang/srt/managers/controller_single.py        +0   -2
python/sglang/srt/managers/schedule_batch.py           +5   -15
python/sglang/srt/managers/tp_worker.py                +10  -6
python/sglang/srt/model_executor/cuda_graph_runner.py  +12  -2
python/sglang/srt/model_executor/model_runner.py       +1   -0
python/sglang/srt/server.py                            +36  -37
python/sglang/srt/server_args.py                       +6   -0
python/sglang/srt/utils.py                             +5   -6
python/sglang/launch_server.py  (+8 / -1)

 """Launch the inference server."""

 import argparse
+import os

 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
@@ -11,4 +13,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args)
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
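
Note on the launcher change above: wrapping launch_server in try/finally guarantees that worker processes spawned by the launcher are cleaned up even when the server loop raises. A minimal, self-contained sketch of that pattern, using plain multiprocessing stand-ins rather than sglang's actual workers:

    import multiprocessing as mp
    import time


    def fake_worker():
        # Stand-in for a model/controller subprocess spawned by the launcher.
        time.sleep(60)


    if __name__ == "__main__":
        workers = [mp.Process(target=fake_worker) for _ in range(2)]
        for w in workers:
            w.start()
        try:
            # Stand-in for launch_server(server_args) crashing mid-flight.
            raise RuntimeError("simulated server crash")
        except RuntimeError:
            pass  # the real launcher re-raises; swallowed here so the demo exits cleanly
        finally:
            # The point of the change: children are reaped no matter how the
            # server exits, so a crash cannot leave orphaned worker processes.
            for w in workers:
                w.terminate()
                w.join()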
python/sglang/srt/hf_transformers_utils.py  (+11 / -5)

...
@@ -233,6 +233,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"
+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
...
@@ -246,14 +248,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

         def encode_patched(
             self,
...
@@ -270,14 +276,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )

         tokenizer.encode = functools.partial(encode_patched, tokenizer)

         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
...
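
The disallowed_special=() change above affects how the patched encode treats input text that happens to contain special-token strings. A small illustration of that behavior with the public tiktoken API, using the stock cl100k_base encoding rather than the custom encoding built here:

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    text = "hello <|endoftext|> world"

    # By default, disallowed_special="all" makes encode() raise if the input
    # contains the textual form of a special token.
    try:
        enc.encode(text)
    except ValueError as e:
        print("rejected by default:", e)

    # Passing disallowed_special=() encodes the same text as ordinary characters.
    print(enc.encode(text, disallowed_special=()))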
python/sglang/srt/managers/controller_multi.py  (+0 / -2)

...
@@ -212,6 +212,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()
python/sglang/srt/managers/controller_single.py  (+0 / -2)

...
@@ -167,6 +167,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()
python/sglang/srt/managers/schedule_batch.py  (+5 / -15)

...
@@ -16,7 +16,6 @@ limitations under the License.
 """Meta data for requests and batches"""

 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union
...
@@ -270,7 +269,7 @@ class Req:
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            logging.warning(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
...
@@ -753,7 +752,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

-    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
+    def sample(self, logits: torch.Tensor):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
...
@@ -791,7 +790,7 @@ class ScheduleBatch:
             )
             if not torch.all(success):
-                logging.warning("Sampling failed, fallback to top_k=1 strategy")
+                logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
                 probs = probs.masked_fill(torch.isnan(probs), 0.0)
                 argmax_ids = torch.argmax(probs, dim=-1)
                 batch_next_token_ids = torch.where(
...
@@ -808,16 +807,6 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

-        if is_multi_node_tp:
-            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
-            # that can cause the TP workers to generate different tokens, so we need to
-            # sync here
-            torch.distributed.all_reduce(
-                batch_next_token_ids,
-                op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
-            )
-
         return batch_next_token_ids
...
@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )
...
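
The hunk at line 808 removes the per-step all_reduce(MIN) over sampled token IDs. Together with the random-seed broadcast added in tp_worker.py below, the change appears to rely on identically seeded TP workers making identical sampling decisions, instead of reconciling divergent tokens with a collective after every step. A toy sketch of that determinism argument, under the assumption that each rank computes identical logits:

    import torch


    def sample_next_token(logits: torch.Tensor, gen: torch.Generator) -> torch.Tensor:
        # Toy sampler: softmax + multinomial driven by an explicit RNG.
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=gen)


    logits = torch.randn(1, 32000)  # pretend both TP ranks computed the same logits
    gen_rank0 = torch.Generator().manual_seed(42)
    gen_rank1 = torch.Generator().manual_seed(42)

    # Same seed + same logits => same token on every rank, so no per-step
    # all_reduce over token ids is required to keep the ranks in agreement.
    assert torch.equal(sample_next_token(logits, gen_rank0),
                       sample_next_token(logits, gen_rank1))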
python/sglang/srt/managers/tp_worker.py  (+10 / -6)

...
@@ -133,6 +133,13 @@ class ModelTpServer:
                 self.model_config.context_len - 1,
                 self.max_total_num_tokens - 1,
             )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)

         # Print info
...
@@ -474,9 +481,7 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(
-                    output.next_token_logits, self.model_runner.is_multi_node_tp
-                )
+                next_token_ids = batch.sample(output.next_token_logits)

                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
...
@@ -636,9 +641,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(
-            output.next_token_logits, self.model_runner.is_multi_node_tp
-        )
+        next_token_ids = batch.sample(output.next_token_logits)

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
...
@@ -879,6 +882,7 @@ def broadcast_recv_input(
         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
...
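
The new seed sync above reuses sglang's broadcast_recv_input helper (shown in part in the last hunk), which serializes a Python object on TP rank 0 and broadcasts it over the CPU process group so every worker draws from the same RNG stream. A minimal sketch of the same idea using torch.distributed.broadcast_object_list instead of the in-tree helper; this is a single-process demo, whereas real multi-node TP would have every rank join the same group:

    import os

    import torch.distributed as dist


    def sync_random_seed(local_seed: int) -> int:
        # Rank 0's seed wins; every other rank overwrites its own value with it.
        obj = [local_seed if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(obj, src=0)  # pickle on rank 0, unpickle elsewhere
        return obj[0]


    if __name__ == "__main__":
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
        print(sync_random_seed(1234))  # every rank ends up with 1234
        dist.destroy_process_group()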
python/sglang/srt/model_executor/cuda_graph_runner.py  (+12 / -2)

...
@@ -84,13 +84,20 @@ def set_torch_compile_config():
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding

         # Common inputs
         self.max_bs = max_batch_size_to_capture
...
@@ -142,7 +149,10 @@ class CudaGraphRunner:
             set_torch_compile_config()

     def can_run(self, batch_size):
-        return batch_size <= self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs

     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
...
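
The can_run change above is the behavioral core of the new --disable-cuda-graph-padding flag: with padding disabled, only batch sizes that were captured exactly may replay a graph, while with padding enabled anything up to the largest captured size can run and gets padded up. A toy model of that policy in plain Python, with no CUDA involved:

    class ToyGraphCache:
        def __init__(self, batch_size_list, disable_padding):
            # Mirrors CudaGraphRunner.graphs being keyed by captured batch size.
            self.graphs = {bs: object() for bs in batch_size_list}
            self.max_bs = max(batch_size_list)
            self.disable_padding = disable_padding

        def can_run(self, batch_size):
            if self.disable_padding:
                return batch_size in self.graphs
            return batch_size <= self.max_bs


    strict = ToyGraphCache([1, 2, 4, 8], disable_padding=True)
    assert strict.can_run(4) and not strict.can_run(3)  # 3 was never captured

    padded = ToyGraphCache([1, 2, 4, 8], disable_padding=False)
    assert padded.can_run(3)  # 3 would be padded up to the captured size 4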
python/sglang/srt/model_executor/model_runner.py  (+1 / -0)

...
@@ -465,6 +465,7 @@ class ModelRunner:
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
...
python/sglang/srt/server.py  (+36 / -37)

...
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
...
@@ -301,27 +300,29 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-            tp_rank_range = list(
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
-            )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
-            )
-            while True:
-                pass
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [
+            i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+        ]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
+            )
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
...
@@ -356,15 +357,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
-        )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
+        )
     assert proc_controller.is_alive() and proc_detoken.is_alive()

     # Add api key authorization
...
@@ -373,12 +370,12 @@ def launch_server(
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()

-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
...
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     )


-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
...
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
...
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
...
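
For reference, the tp_rank_range computation in the multi-node branch above just slices the global TP ranks evenly across nodes. A standalone restatement of that arithmetic:

    def local_tp_ranks(tp_size, nnodes, node_rank):
        # Assumes tp_size divides evenly across nodes, as the launch code does.
        tp_size_local = tp_size // nnodes
        return list(range(node_rank * tp_size_local, (node_rank + 1) * tp_size_local))


    # Example: tp_size=16 over 2 nodes.
    assert local_tp_ranks(16, 2, 0) == list(range(0, 8))   # node 0 hosts TP ranks 0-7
    assert local_tp_ranks(16, 2, 1) == list(range(8, 16))  # node 1 hosts TP ranks 8-15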
python/sglang/srt/server_args.py  (+6 / -0)

...
@@ -79,6 +79,7 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
...
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
...
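
The new flag follows the same argparse pattern as the existing --disable-* options. A standalone check of how the store_true flag parses, reproducing the argument definition from the hunk above without importing sglang:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--disable-cuda-graph-padding",
        action="store_true",
        help="Disable cuda graph when padding is needed. "
        "Still uses cuda graph when padding is not needed.",
    )

    args = parser.parse_args(["--disable-cuda-graph-padding"])
    assert args.disable_cuda_graph_padding is True  # dashes become underscores
    assert parser.parse_args([]).disable_cuda_graph_padding is False  # default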
python/sglang/srt/utils.py  (+5 / -6)

...
@@ -369,14 +369,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)


-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
...
@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
...
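
The skip_pid addition above lets kill_parent_process delegate to kill_child_process without killing itself. A self-contained sketch of the same skip logic with psutil; the helper name here is hypothetical, not the sglang function:

    import multiprocessing as mp
    import os
    import time

    import psutil


    def kill_children_except(pid, skip_pid=None):
        # Same shape as the new skip_pid parameter: kill every child of `pid`
        # except the one we want to keep alive (e.g. the caller itself).
        try:
            parent = psutil.Process(pid)
        except psutil.NoSuchProcess:
            return
        for child in parent.children(recursive=True):
            if child.pid == skip_pid:
                continue
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass


    if __name__ == "__main__":
        keep = mp.Process(target=time.sleep, args=(5,))
        drop = mp.Process(target=time.sleep, args=(5,))
        keep.start()
        drop.start()
        kill_children_except(os.getpid(), skip_pid=keep.pid)
        time.sleep(0.5)
        assert keep.is_alive() and not drop.is_alive()
        keep.terminate()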