ring_comm.py 1.64 KB
Newer Older
Xinchi Huang's avatar
Xinchi Huang committed
1
from typing import Optional
PengGao's avatar
PengGao committed
2

Xinchi Huang's avatar
Xinchi Huang committed
3
4
import torch
import torch.distributed as dist
PengGao's avatar
PengGao committed
5
from loguru import logger
Xinchi Huang's avatar
Xinchi Huang committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


class RingComm:
    """Ring-topology point-to-point exchange over a process group.

    Each rank sends to its successor ``(rank + 1) % world_size`` and receives
    from its predecessor ``(rank - 1) % world_size``. Usage is three-phase:
    queue one or more exchanges with :meth:`send_recv`, launch them together
    with :meth:`commit`, then block on :meth:`wait` before reading any tensor
    returned by :meth:`send_recv`.
    """

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
        """Resolve this rank's ring neighbours.

        Args:
            process_group: Group forming the ring; ``None`` uses the default group.
        """
        self._process_group = process_group
        # P2P ops queued by send_recv and not yet launched by commit.
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        # Request handles from batch_isend_irecv; None while nothing is in flight.
        self._reqs = None

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size

        if process_group is not None:
            # The neighbour ranks above are group-local; P2POp addresses peers
            # by global rank, so translate them.
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Queue sending ``to_send`` to the next rank and receiving from the previous one.

        Nothing is transferred until :meth:`commit`; the returned buffer holds
        valid data only after :meth:`wait`.

        Args:
            to_send: Tensor to send. A non-contiguous tensor is sent through a
                contiguous copy.
            recv_tensor: Optional preallocated receive buffer. When omitted, a
                contiguous buffer with ``to_send``'s shape, dtype and device is
                allocated.

        Returns:
            The receive buffer.

        Raises:
            RuntimeError: If called after :meth:`commit` but before :meth:`wait`.
                Without this check the op would be discarded by :meth:`wait`
                without ever being launched.
        """
        if self._reqs is not None:
            raise RuntimeError("send_recv called after commit; call wait first")

        # NCCL/Gloo point-to-point ops require contiguous tensors. contiguous()
        # is a no-op for an already-contiguous tensor; a copy, if made, is kept
        # alive by the P2POp that references it.
        to_send = to_send.contiguous()
        # empty_like of a contiguous tensor is itself contiguous.
        res = torch.empty_like(to_send) if recv_tensor is None else recv_tensor

        send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res

    def commit(self):
        """Launch all queued ops as a single batched isend/irecv.

        Raises:
            RuntimeError: If a batch is already in flight, or if no ops were queued.
        """
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        if not self._ops:
            # batch_isend_irecv cannot handle an empty op list; fail clearly here.
            raise RuntimeError("commit called with no queued send_recv ops")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        """Block until every op launched by :meth:`commit` completes, then reset.

        Raises:
            RuntimeError: If :meth:`commit` has not been called.
        """
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        # Ready for the next send_recv/commit/wait round.
        self._reqs = None
        self._ops = []