Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c0bb9eb3
Unverified
Commit
c0bb9eb3
authored
Feb 25, 2025
by
Shenggui Li
Committed by
GitHub
Feb 25, 2025
Browse files
[improve] made timeout configurable (#3803)
parent
7036d6fc
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
26 additions
and
1 deletion
+26
-1
docs/references/deepseek.md
docs/references/deepseek.md
+6
-0
python/sglang/srt/distributed/parallel_state.py
python/sglang/srt/distributed/parallel_state.py
+9
-0
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+1
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+3
-1
No files found.
docs/references/deepseek.md
View file @
c0bb9eb3
...
...
@@ -81,3 +81,9 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o
-
**Weight**
: Per-128x128-block quantization for better numerical stability.
**Usage**
: turn on by default for DeepSeek V3 models.
## FAQ
**Question**
: What should I do if model loading takes too long and NCCL timeout occurs?
Answer: You can try to add
`--dist-timeout 3600`
when launching the model; this allows for a 1-hour timeout.
python/sglang/srt/distributed/parallel_state.py
View file @
c0bb9eb3
...
...
@@ -30,6 +30,7 @@ import weakref
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
datetime
import
timedelta
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
...
...
@@ -960,6 +961,7 @@ def init_distributed_environment(
distributed_init_method
:
str
=
"env://"
,
local_rank
:
int
=
-
1
,
backend
:
str
=
"nccl"
,
timeout
:
Optional
[
int
]
=
None
,
):
logger
.
debug
(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s"
,
...
...
@@ -974,13 +976,20 @@ def init_distributed_environment(
"distributed_init_method must be provided when initializing "
"distributed environment"
)
if
timeout
is
not
None
:
assert
isinstance
(
timeout
,
(
int
)),
"timeout must be a number"
assert
timeout
>
0
,
"timeout must be positive"
timeout
=
timedelta
(
seconds
=
timeout
)
# this backend is used for WORLD
torch
.
distributed
.
init_process_group
(
backend
=
backend
,
init_method
=
distributed_init_method
,
world_size
=
world_size
,
rank
=
rank
,
timeout
=
timeout
,
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
c0bb9eb3
...
...
@@ -259,6 +259,7 @@ class ModelRunner:
rank
=
self
.
tp_rank
,
local_rank
=
self
.
gpu_id
,
distributed_init_method
=
dist_init_method
,
timeout
=
self
.
server_args
.
dist_timeout
,
)
initialize_model_parallel
(
tensor_model_parallel_size
=
self
.
tp_size
)
initialize_dp_attention
(
...
...
python/sglang/srt/server_args.py
View file @
c0bb9eb3
...
...
@@ -79,6 +79,7 @@ class ServerArgs:
random_seed
:
Optional
[
int
]
=
None
constrained_json_whitespace_pattern
:
Optional
[
str
]
=
None
watchdog_timeout
:
float
=
300
dist_timeout
:
Optional
[
int
]
=
None
# timeout for torch.distributed
download_dir
:
Optional
[
str
]
=
None
base_gpu_id
:
int
=
0
...
...
@@ -534,6 +535,12 @@ class ServerArgs:
default
=
ServerArgs
.
watchdog_timeout
,
help
=
"Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging."
,
)
parser
.
add_argument
(
"--dist-timeout"
,
type
=
int
,
default
=
ServerArgs
.
dist_timeout
,
help
=
"Set timeout for torch.distributed initialization."
,
)
parser
.
add_argument
(
"--download-dir"
,
type
=
str
,
...
...
python/sglang/test/test_utils.py
View file @
c0bb9eb3
...
...
@@ -503,7 +503,9 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
ret_code
=
run_with_timeout
(
run_one_file
,
args
=
(
filename
,),
timeout
=
timeout_per_file
)
assert
ret_code
==
0
assert
(
ret_code
==
0
),
f
"expected return code 0, but
{
filename
}
returned
{
ret_code
}
"
except
TimeoutError
:
kill_process_tree
(
process
.
pid
)
time
.
sleep
(
5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment