# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from unittest import mock

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup


class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        version = torch.__version__.split(".")[:2]
        major, minor = int(version[0]), int(version[1])
        if major < 1 or (major == 1 and minor < 6):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
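        # wrap() should pick up flatten_parameters / process_group from the enclosing enable_wrap context.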
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters is False

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, nn.Linear)

    def test_wrap_override_defaults(self):
        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters

    def test_auto_wrap(self):
        """
        Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
        ``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            model = auto_wrap(sequential, min_num_params=40)
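        # Each nn.Linear(5, 5) holds 30 params (25 weights + 5 biases), below min_num_params=40, so the
        # top-level Linears stay unwrapped. The nested Sequential totals 60 params and gets wrapped, as
        # does the root module (120 params in total).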
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.Linear)
        assert isinstance(model.module[2], FSDP)
        assert isinstance(model.module[2].module[0], nn.Linear)
        assert isinstance(model.module[2].module[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless if the total param size is greater than the
        min_num_params.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
            model = auto_wrap(sequential, min_num_params=40)
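        # nn.ModuleList is excluded from wrapping, so the container stays a plain ModuleList even though
        # its combined 60 params exceed min_num_params; each child Linear (30 params) is below the threshold.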
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but children are if param size is greater than
        min_num_params
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
            model = auto_wrap(sequential, min_num_params=40)
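        # The ModuleList itself is excluded, but its child nn.Linear(10, 10) has 110 params (> 40) and is
        # therefore wrapped in FSDP.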
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], FSDP)

    def test_auto_wrap_preset_blocklist(self):
        """
        Test to ensure blocklisted modules are not wrapped, and children are not wrapped.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
            model = auto_wrap(sequential, min_num_params=40)
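        # nn.Linear(10, 10) has 110 params (> 40) and gets wrapped; nn.MultiheadAttention is blocklisted by
        # default, so neither it nor its children are wrapped despite their size.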
        assert isinstance(model.module[0], FSDP)
        # Assert children of multihead attention are not wrapped
        assert isinstance(model.module[1], nn.MultiheadAttention)
        assert isinstance(model.module[1].out_proj, nn.Linear)

    def test_auto_wrap_preset_blocklist_custom(self):
        """
        Test to ensure blocklisted modules are not wrapped.
        """
        with enable_wrap(module_blocklist=[nn.Linear], process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential, min_num_params=40)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.ModuleList)

    # todo: currently complains that address is in use, not sure why since I clear the proc group.
    # def test_auto_wrap_smoke(self):
    #     self._auto_wrap_smoke_test(enable_mixed_precision=False)

    def test_auto_wrap_smoke_autocast(self):
        """
        Ensure we can do a forward/backward through an auto-wrapped model.
        """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

    @mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"}, clear=True)
    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        from torch.cuda.amp import autocast

        device = torch.device("cuda")
        torch.cuda.set_device(0)
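        # The mocked MASTER_ADDR / MASTER_PORT above let init_process_group use its default env://
        # rendezvous for this single-rank NCCL group.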
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        with enable_wrap(mixed_precision=enable_mixed_precision):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            model = auto_wrap(sequential, min_num_params=40)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)

        with autocast(enabled=enable_mixed_precision):
            output = model(input)
            loss = F.mse_loss(input, output)
        loss.backward()
        torch.distributed.destroy_process_group()