test_reduce_scatter_bucketer.py 4.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import sys
import unittest
from unittest import mock

from parameterized import parameterized
import torch

from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


def rename_test(testcase_func, param_num, param):
    return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)


CONFIG_OPTIONS = [
    [dict(zip(["bucket_cap_mb", "shard_size"], config))] for config in itertools.product([0, 0.25], [1, 262144])
]


class TestReduceScatterBucketer(unittest.TestCase):
    # TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`.
    def setUp(self):
        major, minor = torch.__version__.split(".")[:2]
        major, minor = int(major), int(minor)
        if major < 1 or (major == 1 and minor < 6):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter")
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_reduce_scatter(self, config):
        spawn_and_init(functools.partial(self._test_reduce_scatter, **config))

    @staticmethod
    def _test_reduce_scatter(rank, group, bucket_cap_mb=None, shard_size=None):
        bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb)
        world_size = group.size()

        tensors = [torch.ones(shard_size).cuda() for _ in range(world_size)]
        tensors[rank].fill_(0)

        input_bytes = shard_size * world_size * 4
        bucket_bytes = bucket_cap_mb * 1024 * 1024

        callback = mock.MagicMock()
        bucketer.reduce_scatter_async(tensors, group, callback_fn=callback)

        if bucket_cap_mb > 0 and input_bytes < bucket_bytes:
            assert callback.call_count == 0
            bucketer.flush()
        assert callback.call_count == 1

        result = callback.call_args[0][0]  # get first positional arg
        assert torch.is_tensor(result), result
        assert torch.all(result == (world_size - 1))

    def test_out_of_order_reduction(self):
        spawn_and_init(self._test_out_of_order_reduction)

    @staticmethod
    def _test_out_of_order_reduction(rank, group):
        bucketer = ReduceScatterBucketer(bucket_cap_mb=0.25)
        world_size = group.size()

        small_tensors = [torch.ones(1).cuda() for _ in range(world_size)]
        big_tensors = [torch.ones(262144).cuda() for _ in range(world_size)]
        more_small_tensors = [torch.ones(2).cuda() for _ in range(world_size)]

        callback1 = mock.MagicMock()
        callback2 = mock.MagicMock()
        callback3 = mock.MagicMock()

        bucketer.reduce_scatter_async(small_tensors, group, callback_fn=callback1)
        assert callback1.call_count == 0
        bucketer.reduce_scatter_async(big_tensors, group, callback_fn=callback2)
        assert callback1.call_count == 0
        assert callback2.call_count == 1
        bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=callback3)
        assert callback1.call_count == 0
        assert callback2.call_count == 1
        assert callback3.call_count == 0

        bucketer.flush()
        assert callback1.call_count == 1
        assert callback2.call_count == 1
        assert callback3.call_count == 1


def spawn_and_init(fn, args=None, **spawn_kwargs):
    if args is None:
        args = ()
    run_fn = functools.partial(init_and_run, fn, args)
    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    dist_init(rank, world_size, filename, filename_rpc)
    group = torch.distributed.new_group()
    fn(rank, group, *args)


if __name__ == "__main__":
    unittest.main()