test_oss_adascale.py 3.95 KB
Newer Older
Min Xu's avatar
Min Xu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test AdaScale with OSS. """

import contextlib
import os
from statistics import mean
import tempfile

import numpy as np
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fairscale.optim import OSS, AdaScale
from fairscale.utils.golden_testing_data import adascale_test_data
from fairscale.utils.testing import skip_if_single_gpu


def _dist_init(rank, world_size, tempfile_name, backend):
    """Join the process group for this worker and bind it to its CUDA device.

    Args:
        rank: Rank of the calling process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the group.
        tempfile_name: Path of the file used for ``file://`` initialization.
        backend: Name of the distributed backend passed to ``init_process_group``.
    """
    init_method = "file://" + tempfile_name
    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    # One device per rank: rank N drives cuda:N.
    torch.cuda.set_device(rank)


def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None):
    """Per-rank worker: train a DDP model with AdaScale and check the results.

    Args:
        rank: Rank of this worker; selects the CUDA device and this rank's slice
            of every input in ``test_case``.
        world_size: Total number of workers in the process group.
        tempfile_name: Path of the file used to initialize the process group.
        test_case: Dict describing the scenario. It holds either ``"input"``
            (a single step) or ``"inputs"`` (a list of steps), each indexed by
            rank, plus ``"expected_gain"``. If ``"expected_mean_weight"`` is
            present, the mean of the per-layer mean weights is checked as well.
        oss: When true, AdaScale wraps an OSS optimizer built on SGD; otherwise
            it wraps plain SGD.
        model: Optional model to train. Defaults to ``Linear(2, 2, bias=False)``.
            The ``"expected_mean_weight"`` check iterates over the model, so it
            needs a container of layers with a ``weight`` (e.g. ``Sequential``).
    """
    _dist_init(rank, world_size, tempfile_name, backend="nccl")

    if model is None:
        model = Linear(2, 2, bias=False)
    model.to("cuda")
    model = DDP(model, device_ids=[rank])
    if oss:
        # For now, we can only wrap AdaScale over OSS. If we do it the other way around,
        # AdaScale needs to take different parameter types, i.e. the parameter list, etc.
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
    else:
        optim = AdaScale(SGD(model.parameters(), lr=0.1))

    # Normalize the single-step form into a one-element list of steps.
    if "input" in test_case:
        inputs = [test_case["input"]]
    else:
        inputs = test_case["inputs"]

    for step_data in inputs:
        # Each step carries one entry per rank; pick ours.
        batch = Tensor(step_data[rank]).cuda()
        out = model(batch)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()

    if "expected_mean_weight" in test_case:
        # Walk every layer of the wrapped model instead of assuming a fixed layer count.
        layer_means = [layer.weight.data.mean().item() for layer in model.module]
        mean_weight = mean(layer_means)
        assert np.allclose(mean_weight, test_case["expected_mean_weight"]), mean_weight

    dist.destroy_process_group()


@skip_if_single_gpu
@pytest.mark.parametrize("test_case", adascale_test_data)
def test_basic(test_case):
    """Test adascale with DDP + OSS with trivial model.

    Spawns two workers that each run ``_test_basic_func`` with ``oss=True`` on
    one entry of ``adascale_test_data``.
    """
    world_size = 2

    # mkstemp() hands back an open OS-level descriptor together with the path.
    # Only the path is passed to the workers, so close the descriptor right away
    # instead of leaking it.
    fd, temp_file_name = tempfile.mkstemp()
    os.close(fd)

    try:
        mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, True), nprocs=world_size, join=True)
    finally:
        # mkstemp() leaves deletion to the caller. The file may already be gone
        # by now (the workers share it for initialization), so tolerate that.
        with contextlib.suppress(FileNotFoundError):
            os.remove(temp_file_name)


@skip_if_single_gpu
@pytest.mark.parametrize("oss", [True, False])
def test_sequential(oss):
    """Test adascale with DDP, with and without OSS, on a sequential model.

    Spawns two workers that each run ``_test_basic_func`` on a four-layer,
    bias-free ``Sequential`` model with fixed weights and checks the recorded
    gain and mean weight.
    """
    world_size = 2

    # Run multiple iterations, check the gain for both oss and non-oss cases.
    #
    # The inputs are picked arbitrarily. I used vectors that are orthogonal.
    #
    # The gain and mean_weight values are recorded from my testing and used here
    # to ensure their value is unchanged from commit to commit unless we can
    # explain why.
    test_case = {
        "inputs": [[[1.0, 0], [0, 1.0]], [[0, 1.0], [1.0, 0]]],
        "expected_gain": 1.0335265132125744,
        "expected_mean_weight": 52.92657661437988,
    }

    # The model.
    model = Sequential(
        Linear(2, 3, bias=False), Linear(3, 4, bias=False), Linear(4, 5, bias=False), Linear(5, 6, bias=False)
    )

    # Weights need to be fixed for deterministic gain.
    model[0].weight.data.copy_(Tensor(range(6)).reshape(3, 2) / mean(range(6)))
    model[1].weight.data.copy_(Tensor(range(12)).reshape(4, 3) / mean(range(12)))
    model[2].weight.data.copy_(Tensor(range(20)).reshape(5, 4) / mean(range(20)))
    model[3].weight.data.copy_(Tensor(range(30)).reshape(6, 5) / mean(range(30)))

    # mkstemp() hands back an open OS-level descriptor together with the path.
    # Only the path is passed to the workers, so close the descriptor right away
    # instead of leaking it.
    fd, temp_file_name = tempfile.mkstemp()
    os.close(fd)

    try:
        mp.spawn(
            _test_basic_func, args=(world_size, temp_file_name, test_case, oss, model), nprocs=world_size, join=True
        )
    finally:
        # mkstemp() leaves deletion to the caller. The file may already be gone
        # by now (the workers share it for initialization), so tolerate that.
        with contextlib.suppress(FileNotFoundError):
            os.remove(temp_file_name)