# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test AdaScale with OSS. """

import os
from statistics import mean
import tempfile

import numpy as np
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fairscale.optim import OSS, AdaScale, AdaScaleWrapper
from fairscale.utils.golden_testing_data import adascale_test_data
from fairscale.utils.testing import skip_if_single_gpu


def _dist_init(rank, world_size, tempfile_name, backend):
    """Join the process group via a shared file and bind this rank to a GPU.

    Args:
        rank: Index of the calling process; also used as its CUDA device id.
        world_size: Number of processes taking part in the group.
        tempfile_name: Path of the file used for the ``file://`` rendezvous.
        backend: Name of the torch.distributed backend to initialize.
    """
    init_url = "file://" + tempfile_name
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)


def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None):
    """Per-rank worker: train a DDP model with AdaScale and check the results.

    Args:
        rank: Index of this process; also used as the CUDA device id and to
            pick this rank's slice of every input batch.
        world_size: Total number of processes in the group.
        tempfile_name: Path of the file used for the ``file://`` rendezvous.
        test_case: Dict holding either ``"input"`` (one batch) or ``"inputs"``
            (a list of batches), each indexable by rank, plus the optional
            ``"expected_gain"`` and ``"expected_mean_weight"`` values.
        oss: How AdaScale and OSS are combined; one of ``"none"``,
            ``"ada-oss"``, ``"wrapper-oss"`` or ``"oss-wrapper"``.
        model: Module to train. Defaults to ``Linear(2, 2, bias=False)``.
            When ``"expected_mean_weight"`` is checked, the module must be
            iterable over layers that each expose a ``weight``.
    """
    _dist_init(rank, world_size, tempfile_name, backend="nccl")

    if model is None:
        model = Linear(2, 2, bias=False)
    model.to("cuda")
    model = DDP(model, device_ids=[rank])

    assert oss in ["none", "ada-oss", "wrapper-oss", "oss-wrapper"]
    if oss == "ada-oss":
        # AdaScale wrapping an OSS instance.
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
    elif oss == "wrapper-oss":
        # AdaScaleWrapper constructing the OSS optimizer itself.
        optim = AdaScaleWrapper(model.parameters(), optim_cls=OSS, optim=SGD, lr=0.1)
    elif oss == "oss-wrapper":
        # OSS on the outside, AdaScaleWrapper as the sharded optimizer.
        optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1)
    else:
        assert oss == "none"
        optim = AdaScale(SGD(model.parameters(), lr=0.1))

    # A test case carries either a single batch or a list of batches.
    if "input" in test_case:
        inputs = [test_case["input"]]
    else:
        inputs = test_case["inputs"]

    for in_data in inputs:
        # Each rank trains on its own slice of the batch.
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if "expected_gain" in test_case:
        assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()

    if "expected_mean_weight" in test_case:
        # Average the per-layer mean weight over every layer of the wrapped
        # model instead of assuming a fixed layer count.
        mean_weight = mean([layer.weight.data.mean().item() for layer in model.module])
        assert np.allclose(mean_weight, test_case["expected_mean_weight"]), mean_weight

    dist.destroy_process_group()


@skip_if_single_gpu
@pytest.mark.parametrize("test_case", adascale_test_data)
def test_basic(test_case):
    """Test adascale with DDP + OSS with trivial model"""
    world_size = 2

    # mkstemp() returns an open descriptor together with the path. Only the
    # path is needed for the file:// rendezvous, so close the descriptor
    # right away instead of leaking it.
    fd, temp_file_name = tempfile.mkstemp()
    os.close(fd)

    try:
        mp.spawn(
            _test_basic_func, args=(world_size, temp_file_name, test_case, "ada-oss"), nprocs=world_size, join=True
        )
    finally:
        # The rendezvous file may already have been removed by the time the
        # workers exit, so a missing file is not an error.
        try:
            os.remove(temp_file_name)
        except FileNotFoundError:
            pass


@skip_if_single_gpu
@pytest.mark.parametrize("oss", ["none", "ada-oss", "wrapper-oss", "oss-wrapper"])
def test_sequential(oss):
    """Test adascale with DDP + OSS with a sequential model"""
    world_size = 2

    # Run multiple iterations, check the gain for both oss and non-oss cases.
    #
    # The inputs are picked arbitrarily. I used vectors that are orthogonal.
    #
    # The gain and mean_weight values are recorded from my testing and used here
    # to ensure their value is unchanged from commit to commit unless we can
    # explain why.
    test_case = {
        "inputs": [[[1.0, 0], [0, 1.0]], [[0, 1.0], [1.0, 0]]],
        "expected_gain": 1.0335265132125744,
        "expected_mean_weight": 52.92657661437988,
    }

    if oss == "oss-wrapper":
        # When OSS wraps AdaScale, the training is numerically different
        # and it exists only to enable future research. So we don't check
        # the gain (OSS doesn't have a gain() function, different rank's
        # gains are different). We just ensure the mean_weight is expected.
        del test_case["expected_gain"]
        test_case["expected_mean_weight"] = 94.93386840820312

    # The model.
    model = Sequential(
        Linear(2, 3, bias=False), Linear(3, 4, bias=False), Linear(4, 5, bias=False), Linear(5, 6, bias=False)
    )

    # Weights need to be fixed for deterministic gain.
    model[0].weight.data.copy_(Tensor(range(6)).reshape(3, 2) / mean(range(6)))
    model[1].weight.data.copy_(Tensor(range(12)).reshape(4, 3) / mean(range(12)))
    model[2].weight.data.copy_(Tensor(range(20)).reshape(5, 4) / mean(range(20)))
    model[3].weight.data.copy_(Tensor(range(30)).reshape(6, 5) / mean(range(30)))

    # mkstemp() returns an open descriptor together with the path. Only the
    # path is needed for the file:// rendezvous, so close the descriptor
    # right away instead of leaking it.
    fd, temp_file_name = tempfile.mkstemp()
    os.close(fd)

    try:
        mp.spawn(
            _test_basic_func, args=(world_size, temp_file_name, test_case, oss, model), nprocs=world_size, join=True
        )
    finally:
        # The rendezvous file may already have been removed by the time the
        # workers exit, so a missing file is not an error.
        try:
            os.remove(temp_file_name)
        except FileNotFoundError:
            pass