# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
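
"""Unit tests for fairscale.utils.reduce_scatter_bucketer.ReduceScatterBucketer."""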

import functools
import itertools
import sys
import unittest
from unittest import mock

from parameterized import parameterized
import torch

from fairscale.utils import torch_version
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


def rename_test(testcase_func, param_num, param):
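    """Build a readable test name from the parameterized config (passed as ``name_func``)."""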
    return "%s_%s" % (
        testcase_func.__name__,
        parameterized.to_safe_name(str(param.args)),
    )


# All combinations of bucket_cap_mb (0 disables bucketing, 0.25 MB otherwise) and
# shard_size (1 element vs. 262144 fp32 elements, i.e. 1 MB per shard).
CONFIG_OPTIONS = [
    [dict(zip(["bucket_cap_mb", "shard_size"], config))] for config in itertools.product([0, 0.25], [1, 262144])
]


class TestReduceScatterBucketer(unittest.TestCase):
    # TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`.
    def setUp(self):
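        # Skip unless the environment supports NCCL-based multi-GPU tests.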
        major, minor, _ = torch_version()
        if major < 1 or (major == 1 and minor < 6):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter")
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_reduce_scatter(self, config):
        spawn_and_init(functools.partial(self._test_reduce_scatter, **config))

    @staticmethod
    def _test_reduce_scatter(rank, group, bucket_cap_mb=None, shard_size=None):
        bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb)
        world_size = group.size()

        # Every rank contributes all-ones shards except its own, so the reduce-scattered
        # shard each rank receives should be filled with (world_size - 1).
        tensors = [torch.ones(shard_size).cuda() for _ in range(world_size)]
        tensors[rank].fill_(0)

        input_bytes = shard_size * world_size * 4  # fp32 tensors: 4 bytes per element
        bucket_bytes = bucket_cap_mb * 1024 * 1024

        callback = mock.MagicMock()
        bucketer.reduce_scatter_async(tensors, group, callback_fn=callback)

        if bucket_cap_mb > 0 and input_bytes < bucket_bytes:
            # Small inputs stay buffered in the bucket, so the callback only fires on flush().
            assert callback.call_count == 0
            bucketer.flush()
        assert callback.call_count == 1

        result = callback.call_args[0][0]  # get first positional arg
        assert torch.is_tensor(result), result
        assert torch.all(result == (world_size - 1))

    def test_out_of_order_reduction(self):
        spawn_and_init(self._test_out_of_order_reduction)

    @staticmethod
    def _test_out_of_order_reduction(rank, group):
        bucketer = ReduceScatterBucketer(bucket_cap_mb=0.25)
        world_size = group.size()

        small_tensors = [torch.ones(1).cuda() for _ in range(world_size)]
        big_tensors = [torch.ones(262144).cuda() for _ in range(world_size)]
        more_small_tensors = [torch.ones(2).cuda() for _ in range(world_size)]

        callback1 = mock.MagicMock()
        callback2 = mock.MagicMock()
        callback3 = mock.MagicMock()

        bucketer.reduce_scatter_async(small_tensors, group, callback_fn=callback1)
        assert callback1.call_count == 0
        # The big tensors exceed the 0.25 MB bucket cap, so their reduction runs immediately
        # and callback2 fires right away, while the bucketed small tensors keep waiting.
        bucketer.reduce_scatter_async(big_tensors, group, callback_fn=callback2)
        assert callback1.call_count == 0
        assert callback2.call_count == 1
        bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=callback3)
        assert callback1.call_count == 0
        assert callback2.call_count == 1
        assert callback3.call_count == 0

        bucketer.flush()
        assert callback1.call_count == 1
        assert callback2.call_count == 1
        assert callback3.call_count == 1


def spawn_and_init(fn, args=None, **spawn_kwargs):
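    """Spawn workers for each configured world size and run ``fn(rank, group, *args)`` in each."""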
    if args is None:
        args = ()
    run_fn = functools.partial(init_and_run, fn, args)
    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
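    """Per-worker entry point: initialize the process group, then run the test body."""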
    dist_init(rank, world_size, filename, filename_rpc)
    group = torch.distributed.new_group()
    fn(rank, group, *args)


if __name__ == "__main__":
    unittest.main()