Unverified Commit 8b84e69f authored by Lianmin Zheng, committed by GitHub

Fix tp token sync for dp attention (#3062)

parent 5de50653
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch import nn
 from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -33,6 +34,10 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group

     def forward(
         self,
@@ -140,7 +145,7 @@ class Sampler(nn.Module):
             torch.distributed.all_reduce(
                 batch_next_token_ids,
                 op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
+                group=self.tp_sync_group,
             )
         return batch_next_token_ids.to(torch.int32)
...
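
The intent of the change, read from the diff: the sampler's MIN all_reduce over next-token ids must run over whichever group of ranks actually shares the same batch. Below is a minimal sketch of that idea, assuming an already-initialized torch.distributed setup; the helper names `pick_tp_sync_group` and `sync_next_token_ids` are illustrative only and are not sglang APIs.

```python
# Sketch only (not part of the patch). With plain tensor parallelism every
# TP rank samples from identical logits, so a MIN all_reduce over the full
# TP group keeps the chosen token ids bit-identical across ranks. With DP
# attention enabled, only the ranks inside one attention TP group share a
# batch, so the reduce must use that smaller group instead.
import torch
import torch.distributed as dist


def pick_tp_sync_group(
    enable_dp_attention: bool,
    attention_tp_group: dist.ProcessGroup,
    tensor_model_parallel_group: dist.ProcessGroup,
) -> dist.ProcessGroup:
    # Mirrors the __init__ change: default to the full TP group, switch to
    # the attention TP group when DP attention is on.
    if enable_dp_attention:
        return attention_tp_group
    return tensor_model_parallel_group


def sync_next_token_ids(
    batch_next_token_ids: torch.Tensor, tp_sync_group: dist.ProcessGroup
) -> torch.Tensor:
    # Every rank in the group computed the same ids; the MIN reduction simply
    # forces agreement if sampling kernels introduced any nondeterminism.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=tp_sync_group)
    return batch_next_token_ids.to(torch.int32)
```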