"vscode:/vscode.git/clone" did not exist on "d71661aadceabd56e175543f8bf14af6d1d36f90"
test_dist.py 1.11 KB
Newer Older
Jeff Rasley's avatar
Jeff Rasley committed
1
import torch
import torch.distributed as dist

from common import distributed_test

import pytest


@distributed_test(world_size=3)
def test_init():
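    # distributed_test (from common.py) launches this test body in world_size
    # processes with torch.distributed already initialized.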
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3


# Demonstration of pytest's parametrization
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Outer test function with inputs from pytest.mark.parametrize(). Uses a distributed
    helper function.
    """
    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == 1138
        assert color == 'purple'

    """Ensure that we can parse args to distributed_test decorated functions. """
    _test_dist_args_helper(number, color=color)


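# world_size=[1, 2, 4]: exercise the all-reduce under each of these process counts.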
@distributed_test(world_size=[1, 2, 4])
def test_dist_allreduce():
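    # Each rank contributes a tensor filled with (rank + 1), so the all-reduce
    # sum over N ranks should equal 1 + 2 + ... + N = N * (N + 1) / 2.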
    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
    result = torch.ones(1, 3).cuda() * sum_of_ranks

    dist.all_reduce(x)
    assert torch.all(x == result)