Unverified commit 26c04065, authored by Yang Fan and committed by GitHub
Browse files

[Bugfix] Fix distributed bug in Qwen2.5-VL & Qwen2.5-Omni (#16907)

parent 4c41278b
@@ -198,9 +198,8 @@ class Qwen2_5_VisionMLP(nn.Module):
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
"""All-gather the input tensor interleavely across model parallel group."""
import torch.distributed as dist
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
- dist.all_gather(gathered_tensors, local_tensor)
+ parallel_state.get_tp_group().all_gather(gathered_tensors, local_tensor)
gathered_tensors_split = [
torch.split(tensor, hidden_size // tp_size, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment