test_dist.py 1 KB
Newer Older
Jeff Rasley's avatar
Jeff Rasley committed
1
import torch
Shaden Smith's avatar
Shaden Smith committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.distributed as dist

from common import distributed_test

import pytest


@distributed_test(world_size=3)
def test_init():
    """Verify the process group is up with the three ranks requested by the decorator."""
    assert dist.is_initialized()
    world = dist.get_world_size()
    assert world == 3
    # Every rank must fall inside the group reported above.
    assert dist.get_rank() < world


# Demonstration of pytest's parametrization
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Ensure arguments reach ``distributed_test``-decorated functions.

    The outer test receives its inputs from ``pytest.mark.parametrize()`` and
    forwards them to an inner helper that runs under
    ``distributed_test(world_size=2)``. One input is passed positionally and
    one by keyword.
    """
    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == 1138
        assert color == 'purple'

    # ``color`` is passed by keyword: if the keyword argument were dropped on
    # the way through the decorator, the helper would fall back to its 'red'
    # default and the 'purple' assertion above would fail.
    _test_dist_args_helper(number, color=color)
Jeff Rasley's avatar
Jeff Rasley committed
30
31
32
33
34
35
36
37


@distributed_test(world_size=2)
def test_dist_allreduce():
    """All-reduce a per-rank tensor and check every rank receives the sum.

    Rank ``r`` contributes a (1, 3) tensor filled with ``r + 1``. After the
    default (sum) ``all_reduce``, every rank should hold
    ``1 + 2 + ... + world_size`` in each element.
    """
    world_size = dist.get_world_size()
    x = torch.ones(1, 3) * (dist.get_rank() + 1)
    # Sum of 1..world_size. It is derived rather than hard-coded so the
    # expectation stays correct if the decorator's world_size is changed.
    expected = torch.ones(1, 3) * (world_size * (world_size + 1) // 2)
    dist.all_reduce(x)
    assert torch.all(x == expected)