# distributed_tensor.py
import torch
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup, get_global_rank


def assert_tensor_equal_over_group(tensor: torch.Tensor, group: ProcessGroup, assert_: bool = True) -> bool:
    """Check that `tensor` holds the same values on every rank of `group`.

    Tensors are assumed to already have the same shape and dtype on all ranks.
    With `assert_=True` this raises on any rank whose tensor differs from the
    reference rank's; with `assert_=False` it returns whether all ranks matched.
    """
    reference_rank = 0
    if dist.get_rank(group) == reference_rank:
        reference_tensor = tensor
    else:
        reference_tensor = torch.empty_like(tensor)
    # Broadcast the reference rank's tensor to every rank in the group
    # (`src` must be a global rank, hence the group-to-global translation).
    dist.broadcast(
        reference_tensor,
        src=get_global_rank(group=group, group_rank=reference_rank),
        group=group,
    )
    if assert_:
        # Hard failure mode: raise on any rank whose tensor differs from the reference.
        torch.testing.assert_close(tensor, reference_tensor, atol=0, rtol=0)
        return True
    # Soft failure mode: compare locally, then gather every rank's verdict so
    # that all ranks return the same boolean result.
    result = torch.allclose(tensor, reference_tensor, atol=0.0, rtol=0.0)
    results = [None] * group.size()
    dist.all_gather_object(results, result, group)
    return all(results)
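

# Minimal usage sketch (an addition, not part of the original module). It assumes
# the script is launched with torchrun, e.g.
#   torchrun --nproc_per_node=2 distributed_tensor.py
# and that nanotron's ProcessGroup is the plain torch.distributed.ProcessGroup,
# so the default WORLD group can be passed straight in.
if __name__ == "__main__":
    import torch.distributed

    torch.distributed.init_process_group(backend="gloo")
    # The same tensor is built on every rank, so the check should report equality.
    t = torch.arange(4, dtype=torch.float32)
    assert assert_tensor_equal_over_group(t, group=torch.distributed.group.WORLD, assert_=False)
    torch.distributed.destroy_process_group()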