Unverified commit 022614d2, authored Jan 22, 2025 by Lianmin Zheng, committed by GitHub on Jan 22, 2025
Add some flags to allow sync token ids across TP ranks (#3060)
Parent: b8ab989f
Showing 1 changed file with 23 additions and 1 deletion (+23, -1)
python/sglang/srt/layers/sampler.py (+23, -1)
@@ -2,12 +2,18 @@ import logging
 from typing import List

 import torch
+import torch.distributed as dist
 from torch import nn

+from sglang.srt.distributed import get_tensor_model_parallel_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import (
+    crash_on_warnings,
+    get_bool_env_var,
+    is_flashinfer_available,
+)

 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -20,6 +26,8 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)

+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+

 class Sampler(nn.Module):
     def __init__(self):
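Note: SYNC_TOKEN_IDS_ACROSS_TP is evaluated once, when sampler.py is imported, so the environment variable has to be set before the server process starts. As a rough, illustrative stand-in (not the sglang source; the accepted truthy spellings below are an assumption), a boolean env-var helper like get_bool_env_var can be thought of as:

import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as "enabled"; anything else is False.
    return os.getenv(name, default).strip().lower() in ("true", "1", "yes")

# Set the variable before sglang is imported/launched, e.g.
#   SYNC_TOKEN_IDS_ACROSS_TP=1 python -m sglang.launch_server ...
os.environ["SYNC_TOKEN_IDS_ACROSS_TP"] = "true"
print(get_bool_env_var_sketch("SYNC_TOKEN_IDS_ACROSS_TP"))  # True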
@@ -121,6 +129,20 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]

+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids.to(torch.int32)

     def _apply_custom_logit_processor(
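Note: the choice of ReduceOp.MIN is what makes the sync cheap and safe. If every TP rank already sampled the same token IDs, the MIN all-reduce is a no-op; if a rare non-deterministic kernel made one rank diverge, all ranks deterministically settle on the element-wise minimum, so they keep decoding the same tokens instead of hanging out of sync. A minimal, self-contained sketch of this behavior (not part of the commit; the gloo backend, port, and token values are arbitrary demo choices):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend the two "TP ranks" sampled different IDs for the second token,
    # e.g. because of a non-deterministic sampling kernel.
    batch_next_token_ids = torch.tensor(
        [11, 42 if rank == 0 else 43, 7], dtype=torch.int64
    )

    # After the MIN all-reduce, both ranks hold identical IDs and stay in lockstep.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN)
    print(f"rank {rank}: {batch_next_token_ids.tolist()}")  # [11, 42, 7] on both ranks

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)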