test_fsdp_overlap.py 8.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP and ensure expected overlapping between all_gather and forward. """

from statistics import mean
import time
from unittest.mock import patch

import pytest
import torch
from torch.cuda import Event
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.nn import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import (
    dist_init,
    get_cycles_per_ms,
    skip_if_single_gpu,
    teardown,
    temp_files_ctx,
    torch_version,
)


class Layer(nn.Module):
    """Fake layer that burns a fixed number of GPU cycles in ``forward``.

    Every forward call records a pair of CUDA timing events around the fake
    compute so the duration can be read back afterwards with ``get_time``.

    Args:
        compute_cycles: cycle count handed to ``torch.cuda._sleep``; values
            <= 0 skip the sleep entirely.
        has_params: when True, the layer owns a 1-element trainable parameter
            that is added to the input during forward.
    """

    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        # Stays a plain ``None`` attribute (not a registered parameter) when
        # the layer has nothing to own.
        self.optional_param = nn.Parameter(torch.rand(1)) if has_params else None

    def forward(self, x):
        # Fresh events on every call so ``get_time`` reflects the latest forward.
        start = Event(enable_timing=True)
        end = Event(enable_timing=True)
        self.e1, self.e2 = start, end

        start.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        param = self.optional_param
        if param is not None:
            # Use the parameter so it becomes part of the autograd graph.
            x = x + param
        end.record()
        return x

    def get_time(self):
        """Return the time elapsed between the two events of the last forward."""
        return self.e1.elapsed_time(self.e2)


def _create_model(fsdp_config, compute_cycles, has_params: bool):
    """Build a CUDA model made of four individually wrapped ``Layer``s.

    Inside the ``enable_wrap`` context (wrapper class ``FSDP``, extra options
    from ``fsdp_config``), each layer is wrapped on its own. The enclosing
    ``nn.Sequential`` is then wrapped once more and moved to the GPU.
    """
    with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
        wrapped_layers = [wrap(Layer(compute_cycles, has_params)) for _ in range(4)]
        outer = wrap(nn.Sequential(*wrapped_layers))
        model = outer.cuda()
    return model


class Min10:
    """Keep only the ``capacity`` smallest values seen so far (10 by default).

    Used to filter timing noise: nondeterministic system events disturb
    individual measurements, so only the smallest samples are retained.

    Args:
        capacity: maximum number of values retained; must be >= 1.

    Raises:
        ValueError: if ``capacity`` is smaller than 1.
    """

    def __init__(self, capacity: int = 10):
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        # Unordered collection of the smallest values seen so far.
        self.data = []

    def add(self, new_data):
        """Record ``new_data``, evicting the current largest value when over capacity."""
        self.data.append(new_data)
        if len(self.data) > self.capacity:
            # O(capacity) eviction instead of a full sort per call. On a tie
            # with the maximum, one equal value is dropped, so the retained
            # values are the same either way.
            self.data.remove(max(self.data))

    def avg(self):
        """Return the mean of the retained values.

        Raises:
            statistics.StatisticsError: if no value has been added yet.
        """
        return mean(self.data)


def _assert_much_shorter(short_times, long_times, debug_string, factor=10):
    """Assert that every short time, scaled by ``factor``, is below every long time."""
    for short_t in short_times:
        for long_t in long_times:
            assert short_t * factor < long_t, f"{short_t} * {factor} < {long_t} in " + debug_string


def _distributed_worker(
    gpu_id, world_size, fsdp_config, tempfile, tempfile_rpc,
):
    """Per-process body of ``test_forward_overlap``.

    Times the wrapped model under four workloads (no work, all-gather delay
    only, compute only, both). It then asserts three things:

    - the CPU runs ahead of the GPU;
    - GPU timings grow when there is real GPU work;
    - with more than one rank, the all-gather overlaps with forward compute.

    Args:
        gpu_id: CUDA device index, also used as this process's rank.
        world_size: number of participating processes.
        fsdp_config: keyword options passed to ``_create_model``.
        tempfile, tempfile_rpc: files forwarded to ``dist_init``.
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile, tempfile_rpc)
    assert result, "Dist init failed"

    # Save the original torch.distributed.all_gather function since we will
    # patch it to include an artificial delay.
    orig_all_gather = torch.distributed.all_gather

    def run(compute_cycles, all_gather_cycles):
        """Time 20 fwd/bwd iterations; return averages of the 10 smallest samples."""
        # Layers only get parameters (and thus something to all-gather) when
        # an all-gather delay is requested.
        has_params = all_gather_cycles > 0
        model = _create_model(fsdp_config, compute_cycles, has_params)

        # Get the input and set its requires_grad to True because we have a
        # fake compute in the forward pass.
        batch = torch.rand(1).cuda()
        batch.requires_grad = True

        # We run 20 iterations but only collect timing data from the minimal 10
        # data points because nondeterministic system events can disturb the timing.
        cpu_iter = Min10()
        cpu_wait = Min10()
        gpu_compute = Min10()
        gpu_total = Min10()
        for _ in range(20):
            # Two events for measuring the overall forward time.
            fwd_start = Event(enable_timing=True)
            fwd_end = Event(enable_timing=True)

            cpu_start = time.process_time()

            all_gather_called = False

            def _delayed_all_gather(*args, **kwargs):
                nonlocal all_gather_called
                all_gather_called = True
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)

            # forward pass
            #
            # Even though both events are on the compute stream, since compute
            # depends on all_gather, fwd_end - fwd_start includes all_gather time.
            fwd_start.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
                out = model(batch)
                if has_params and world_size > 1:
                    assert all_gather_called
                else:
                    assert not all_gather_called
            fwd_end.record()

            # backward pass
            out.backward()
            if torch_version() >= (1, 7, 0):
                model.zero_grad(set_to_none=True)
            else:
                for p in model.parameters():
                    p.grad = None

            cpu_iter_time = time.process_time() - cpu_start

            # Wait for the GPU: copying ``out`` to the host blocks until it is computed.
            out.item()
            cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

            # Per-layer fake compute times, summed below.
            times = [mod.get_time() for mod in model.modules() if isinstance(mod, Layer)]

            # GPU compute + all_gather time.
            overall_gpu_time = fwd_start.elapsed_time(fwd_end)

            cpu_iter.add(cpu_iter_time)
            cpu_wait.add(cpu_wait_for_gpu_time)
            gpu_compute.add(sum(times))
            gpu_total.add(overall_gpu_time)

        del model
        return {
            "cpu_iter": cpu_iter.avg(),
            "cpu_wait": cpu_wait.avg(),
            "gpu_compute": gpu_compute.avg(),
            "gpu_total": gpu_total.avg(),
        }

    # Presumably ~100 ms worth of GPU sleep cycles (see get_cycles_per_ms).
    sleep_cycles = int(100 * get_cycles_per_ms())

    e1 = run(0, 0)  # no compute, no all-gather
    e2 = run(0, sleep_cycles)  # no compute, only all-gather
    e3 = run(sleep_cycles, 0)  # only compute, no all-gather
    e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
    debug_string = f"\nrank{rank}:\n  e1: {e1}\n  e2: {e2}\n  e3: {e3}\n  e4: {e4}"
    print(debug_string)

    # Check the cpu/gpu timing. CPU should run ahead of GPU. Therefore, cpu-gpu
    # wait should be long, except when there is no real work on GPU.
    #
    # If the assertions fail below, we likely have a cpu-gpu wait in the forward/backward pass.
    short_times = [e1["cpu_iter"], e2["cpu_iter"], e3["cpu_iter"], e4["cpu_iter"], e1["cpu_wait"]]
    long_times = [e3["cpu_wait"], e4["cpu_wait"]]
    if world_size == 1:
        short_times.append(e2["cpu_wait"])  # all gather should not be happening.
    else:
        long_times.append(e2["cpu_wait"])  # all gather should happen and prolong the cpu-gpu wait.
    # 10X longer is a safe margin, since the GPU work timing is around 100X that
    # of the CPU.
    _assert_much_shorter(short_times, long_times, debug_string)

    # Check the GPU timing.
    short_times = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
    long_times = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
    if world_size == 1:
        short_times.append(e2["gpu_total"])  # all gather should not be happening.
    else:
        long_times.append(e2["gpu_total"])  # all gather should happen and prolong the GPU time.
    # 10X longer is a safe margin, since the time is around 100X longer when
    # there is work on GPU vs. no work.
    _assert_much_shorter(short_times, long_times, debug_string)

    # Check the GPU overlapping when there is all-gather.
    if world_size > 1:
        compute_only = e3["gpu_compute"]
        all_gather_only = e2["gpu_total"]
        both = e4["gpu_total"]
        assert compute_only + all_gather_only > 1.1 * both, (
            f"{compute_only} + {all_gather_only} > 1.1 * {both} in " + debug_string
        )

    teardown()


@skip_if_single_gpu
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("mixed", ["mixed", "full"])
def test_forward_overlap(world_size, flatten, mixed):
    """Spawn ``world_size`` workers that check all-gather/forward overlap.

    Parametrized over flattened vs. unflattened parameters and mixed vs. full
    precision.
    """
    use_flat_params = flatten == "flatten"
    use_mixed_precision = mixed == "mixed"
    fsdp_config = dict(flatten_parameters=use_flat_params, mixed_precision=use_mixed_precision)
    with temp_files_ctx(2) as temp_files:
        init_file, rpc_file = temp_files[0], temp_files[1]
        worker_args = (world_size, fsdp_config, init_file, rpc_file)
        mp.spawn(_distributed_worker, args=worker_args, nprocs=world_size)