test_dist.py 815 Bytes
Newer Older
Shaden Smith's avatar
Shaden Smith committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch.distributed as dist

from common import distributed_test

import pytest


@distributed_test(world_size=3)
def test_init():
    """Check that the decorated test runs inside an initialized 3-rank group.

    Asserts that torch.distributed reports an initialized process group whose
    world size is 3, and that this process's rank is below that size.
    """
    world_size = 3
    assert dist.is_initialized()
    assert dist.get_world_size() == world_size
    assert dist.get_rank() < world_size


# Demonstration of pytest's parameterization
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Outer test function with inputs from pytest.mark.parametrize(). Uses a distributed
    helper function.

    Checks that both a positional argument and a keyword argument reach a
    function decorated with ``distributed_test``.
    """
    # Bind the expected values under distinct names: the helper's own ``color``
    # parameter shadows the outer one, so comparing ``color == color`` inside
    # the helper would be a tautology.
    expected_number, expected_color = number, color

    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == expected_number
        # The 'red' default must have been overridden by the keyword argument.
        assert color == expected_color

    # Ensure that we can pass positional and keyword args to
    # distributed_test-decorated functions.
    _test_dist_args_helper(number, color=color)