Unverified commit 8b84e69f authored by Lianmin Zheng, committed by GitHub

Fix tp token sync for dp attention (#3062)

parent 5de50653
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -33,6 +34,10 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
         self,
@@ -140,7 +145,7 @@ class Sampler(nn.Module):
             torch.distributed.all_reduce(
                 batch_next_token_ids,
                 op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
+                group=self.tp_sync_group,
             )
 
         return batch_next_token_ids.to(torch.int32)
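For readers skimming the diff: the fix resolves the token-ID sync group once in `Sampler.__init__` instead of always all-reducing over the full tensor-parallel group. When data-parallel attention is enabled, only the attention TP group shares the same batch, so the final `all_reduce` over `batch_next_token_ids` has to stay within that group. Below is a minimal, hedged sketch of the selection logic; the helper functions and their signatures are simplified stand-ins for illustration, not SGLang's actual APIs.

```python
# Minimal sketch (not SGLang's actual code) of the group selection this commit
# introduces: pick the sync group once, then reuse it for the token-ID sync.
import torch
import torch.distributed as dist


def pick_tp_sync_group(
    enable_dp_attention: bool,
    full_tp_group: dist.ProcessGroup,
    attention_tp_group: dist.ProcessGroup,
) -> dist.ProcessGroup:
    # With DP attention, ranks in different data-parallel replicas hold
    # different batches, so the sync must be restricted to the attention
    # TP group; otherwise fall back to the full tensor-parallel group.
    return attention_tp_group if enable_dp_attention else full_tp_group


def sync_next_token_ids(
    batch_next_token_ids: torch.Tensor, tp_sync_group: dist.ProcessGroup
) -> torch.Tensor:
    # Same pattern as the patched forward(): a MIN all-reduce keeps every rank
    # in the chosen group on identical sampled token IDs.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=tp_sync_group)
    return batch_next_token_ids.to(torch.int32)
```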