"vscode:/vscode.git/clone" did not exist on "b04cf65b2b0e3534ebe8de750fd1493df34f8960"
Unverified commit 022614d2, authored by Lianmin Zheng, committed by GitHub

Add some flags to allow sync token ids across TP ranks (#3060)

parent b8ab989f
@@ -2,12 +2,18 @@ import logging
 from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import (
+    crash_on_warnings,
+    get_bool_env_var,
+    is_flashinfer_available,
+)
 
 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -20,6 +26,8 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
@@ -121,6 +129,20 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]
 
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely, so we also do the sync when a grammar is used.
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids.to(torch.int32)
 
     def _apply_custom_logit_processor(
...
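For illustration, here is a minimal standalone sketch of the synchronization pattern the patch gates behind SYNC_TOKEN_IDS_ACROSS_TP: each tensor-parallel rank holds its own copy of the sampled token ids, and an all-reduce with ReduceOp.MIN forces every rank to agree even if a non-deterministic kernel made the local results diverge. Everything about the process-group setup below (gloo backend, MASTER_ADDR/MASTER_PORT, the two-process spawn, the worker function) is an assumption made to keep the example self-contained; it is not how SGLang initializes its TP group.

# Illustrative sketch only; not SGLang code. Assumes a local 2-process run.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Hypothetical local rendezvous; SGLang sets up its TP group differently.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend rank 1 sampled a different token id for the second position
    # (e.g. due to a rare non-deterministic sampling kernel).
    batch_next_token_ids = torch.tensor([17, 42 + rank, 99], dtype=torch.int64)

    # Same collective as in the patch: take the element-wise minimum across
    # ranks so every rank ends up with identical token ids.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN)

    print(f"rank {rank}: {batch_next_token_ids.tolist()}")  # both ranks print [17, 42, 99]
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

On the choice of reduction: when the ranks already agree (the common case), any element-wise reduction returns the same values, and when they diverge, MIN is presumably just a cheap, order-independent way to pick one winner deterministically on every rank. Enabling the sync via the SYNC_TOKEN_IDS_ACROSS_TP environment variable (read through get_bool_env_var in the diff) costs one extra all-reduce per sampling step, which is the overhead the default path avoids.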