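"""Tests for nanotron's pipeline-parallel P2P communication.

Sends a tensor from pipeline rank 0 to rank 1 and back, covering contiguous vs.
non-contiguous tensors and full tensors vs. non-trivial slices.
"""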
import contextlib

import pytest
import torch
from helpers.exception import assert_fail_with
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.p2p import P2P


@pytest.mark.skipif(available_gpus() < 2, reason="Testing test_check_send_recv_tensor requires at least 2 gpus")
@pytest.mark.parametrize("send_contiguous", [True, False])
@pytest.mark.parametrize("full", [True, False])
@rerun_if_address_is_in_use()
def test_check_send_recv_tensor(send_contiguous: bool, full: bool):
    init_distributed(tp=1, dp=1, pp=2)(_test_check_send_recv_tensor)(send_contiguous=send_contiguous, full=full)


def _test_check_send_recv_tensor(parallel_context: ParallelContext, send_contiguous: bool, full: bool):
    p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda"))
    if dist.get_rank(p2p.pg) == 0:
        tensor_to_send = torch.randn(3, 5, dtype=torch.float, device=torch.device("cuda"))
        if send_contiguous:
            assert tensor_to_send.is_contiguous()
        else:
            tensor_to_send = tensor_to_send.transpose(0, 1)
            assert not tensor_to_send.is_contiguous()

        # `full` defines whether we send the full tensor or only a non-trivial slice of it
        if not full:
            tensor_to_send = tensor_to_send[1:3]

    if not send_contiguous and not full:
        # Sending a non-contiguous slice is expected to raise an AssertionError telling us to call `contiguous` before sending.
        send_first_context = assert_fail_with(
            AssertionError,
            error_msg="Expect storage_size to be smaller than tensor size. It might not be true, when you use slicing for example though. We probably don't want to support it in our P2P system",
        )
        fail_at_first_send = True
    else:
        send_first_context = contextlib.nullcontext()
        fail_at_first_send = False

    # Send the tensor back and forth through the P2P protocol and check that we get the same tensor back.
    if dist.get_rank(p2p.pg) == 0:
        with send_first_context:
            handles = p2p.isend_tensors([tensor_to_send], to_rank=1)
        if fail_at_first_send:
            # Early return: we already caught the expected error above
            return
        for handle in handles:
            handle.wait()
        tensor_travelled_back_and_forth = p2p.recv_tensors(1, from_rank=1)[0]
        torch.testing.assert_close(tensor_to_send, tensor_travelled_back_and_forth, atol=0, rtol=0)
    elif dist.get_rank(p2p.pg) == 1:
        # Rank 0's send fails in this case, so we early-return before waiting on the recv to avoid hanging
        tensors, handles = p2p.irecv_tensors(1, from_rank=0)
        if fail_at_first_send:
            return
        for handle in handles:
            handle.wait()
        tensor_to_recv = tensors[0]
        p2p.send_tensors([tensor_to_recv], to_rank=0)
    else:
        raise ValueError(f"Unexpected rank in a 2-rank pipeline group: {dist.get_rank(p2p.pg)}")

    if not full and send_contiguous:
        # We can check that we haven't sent the entire storage: storage regions not accessed by the tensor are not sent
        if dist.get_rank(p2p.pg) == 0:
            # The first elements of the two storages should differ, since that region is not supposed to be communicated when the tensor is not full.
            print(tensor_to_send.untyped_storage()[:4], tensor_travelled_back_and_forth.untyped_storage()[:4])
            print(tensor_to_send.as_strided(size=(1,), stride=(1,), storage_offset=0))
            print(tensor_travelled_back_and_forth.as_strided(size=(1,), stride=(1,), storage_offset=0))
            assert not torch.allclose(
                tensor_to_send.as_strided(size=(1,), stride=(1,), storage_offset=0),
                tensor_travelled_back_and_forth.as_strided(size=(1,), stride=(1,), storage_offset=0),
            )

    parallel_context.destroy()
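

# Minimal sketch for running a single configuration directly (e.g. `python test_p2p.py`)
# rather than through pytest; it assumes `init_distributed` spawns the required
# pipeline ranks itself, exactly as it does in `test_check_send_recv_tensor` above.
if __name__ == "__main__":
    init_distributed(tp=1, dp=1, pp=2)(_test_check_send_recv_tensor)(send_contiguous=True, full=True)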