# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test AdaScale with OSS. """

from statistics import mean
import tempfile

import numpy as np
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fair_dev.testing.golden_testing_data import adascale_test_data
from fair_dev.testing.testing import skip_if_single_gpu
from fairscale.optim import OSS, AdaScale, AdaScaleWrapper


def _dist_init(rank, world_size, tempfile_name, backend):
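    # Rendezvous over a shared temp file (file:// init method), then pin this
    # rank to its own GPU.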
    url = "file://" + tempfile_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None):
    _dist_init(rank, world_size, tempfile_name, backend="nccl")

    if model is None:
        model = Linear(2, 2)
        model.bias.data.fill_(0.0)

    model.to("cuda")
    model = DDP(model, device_ids=[rank])

    assert oss in ["none", "ada-oss", "wrapper-oss", "oss-wrapper"]
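    # Four optimizer compositions are exercised:
    #   none        - plain AdaScale around SGD, no sharding
    #   ada-oss     - AdaScale wrapping an OSS-sharded SGD
    #   wrapper-oss - AdaScaleWrapper constructing the OSS optimizer itself
    #   oss-wrapper - OSS sharding an AdaScaleWrapper (research-only, see test_sequential)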
    if oss == "ada-oss":
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
    elif oss == "wrapper-oss":
        optim = AdaScaleWrapper(model.parameters(), optim_cls=OSS, optim=SGD, lr=0.1)
    elif oss == "oss-wrapper":
        optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1)
    else:
        assert oss == "none"
        optim = AdaScale(SGD(model.parameters(), lr=0.1))

    if "input" in test_case:
        inputs = [test_case["input"]]
    else:
        inputs = test_case["inputs"]

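    # Forward/backward/step on each rank; AdaScale accumulates gradient
    # statistics during backward, which determine the gain() checked below.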
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if "expected_gain" in test_case:
        assert np.allclose(optim.gain(), test_case["expected_gain"]), "{} vs {}".format(
            optim.gain(), test_case["expected_gain"]
        )

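    # Golden-value check: the mean of all layer weights after training must
    # match the recorded value, guarding against unexplained numerical changes.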
    if "expected_mean_weight" in test_case:
        mean_weight = mean([model.module[i].weight.data.mean().item() for i in range(4)])
        assert np.allclose(mean_weight, test_case["expected_mean_weight"]), mean_weight

    dist.destroy_process_group()


@skip_if_single_gpu
@pytest.mark.parametrize("test_case", adascale_test_data)
def test_basic(test_case):
    """Test adascale with DDP + OSS with trivial model"""
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, "ada-oss"), nprocs=world_size, join=True)


@skip_if_single_gpu
@pytest.mark.parametrize("oss", ["none", "ada-oss", "wrapper-oss", "oss-wrapper"])
def test_sequential(oss):
    """Test adascale with DDP + OSS with a sequential model"""
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    # Run multiple iterations and check the gain for both the oss and non-oss cases.
    #
    # The inputs are picked arbitrarily; I used orthogonal vectors.
    #
    # The gain and mean_weight values were recorded from my testing and are used
    # here to ensure they stay unchanged from commit to commit unless we can
    # explain why.
    test_case = {
        "inputs": [[[1.0, 0], [0, 1.0]], [[0, 1.0], [1.0, 0]]],
        "expected_gain": 1.0335265132125744,
        "expected_mean_weight": 52.92657661437988,
    }

    if oss == "oss-wrapper":
        # When OSS wraps AdaScale, training is numerically different, and this
        # composition exists only to enable future research. We therefore skip
        # the gain check (OSS has no gain() function, and each rank computes a
        # different gain) and only verify the expected mean_weight.
        del test_case["expected_gain"]
        test_case["expected_mean_weight"] = 94.93386840820312

    # The model.
    model = Sequential(
        Linear(2, 3, bias=False), Linear(3, 4, bias=False), Linear(4, 5, bias=False), Linear(5, 6, bias=False)
    )

    # Weights need to be fixed for deterministic gain.
    model[0].weight.data.copy_(Tensor(range(6)).reshape(3, 2) / mean(range(6)))
    model[1].weight.data.copy_(Tensor(range(12)).reshape(4, 3) / mean(range(12)))
    model[2].weight.data.copy_(Tensor(range(20)).reshape(5, 4) / mean(range(20)))
    model[3].weight.data.copy_(Tensor(range(30)).reshape(6, 5) / mean(range(30)))

    mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, oss, model), nprocs=world_size, join=True)