"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "89d80b58e123acc3074ac3bf6dd77dba6665ea0d"
Unverified Commit 90019096 authored by shenggan, committed by GitHub
Browse files

Merge pull request #4 from hpcaitech/fix_gather

fix minor bug in gather
parents e96b76b0 77642096
......@@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
from torch import Tensor
from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank,
from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from .core import ensure_divisibility
......@@ -33,7 +33,7 @@ def _split(tensor: Tensor, dim: int = -1) -> Tensor:
split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[get_tensor_model_parallel_src_rank()].contiguous()
output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
return output
......@@ -49,7 +49,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_gather(list(tensor_list), tensor, group=get_tensor_model_parallel_group(), async_op=False)
else:
tensor_list = [torch.ones_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
tensor_list = [torch.empty_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
dist.all_gather(tensor_list, tensor, group=get_tensor_model_parallel_group(), async_op=False)
output = torch.cat(tensor_list, dim=dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment