test_zero_context_ancestry.py
'''Copyright The Microsoft DeepSpeed Team'''

import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from utils import setup_serial_env
from unit.common import DistributedTest

config = {
    "train_batch_size": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 138.
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}
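# Note: stage3_param_persistence_threshold=1 is set so that even the tiny 5-element test
# params below are partitioned instead of being kept persistent on every rank.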


# test that sub-classes get params that aren't prematurely partitioned (and would thus require gathering)
# fixed by https://github.com/microsoft/DeepSpeed/pull/1202
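# Each class in the GrandPa -> Pa -> Son hierarchy adds 1 to every ancestor param in its own
# __init__; the final values (2, 3 and 4) asserted below are only reached if the params are
# still whole (not yet partitioned) while the constructors run.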
class GrandPa(torch.nn.Module):
    def __init__(self, *args):
        super().__init__(*args)
        self.param_grandpa = torch.nn.Parameter(torch.ones(5))
        self.param_grandpa.data = (self.param_grandpa.data +
                                   1).data  # test param is not yet partitioned


class Pa(GrandPa):
    def __init__(self, *args):
        super().__init__(*args)
        self.param_pa = torch.nn.Parameter(torch.ones(5))
        self.param_pa.data = (self.param_pa.data +
                              1).data  # test param is not yet partitioned
        self.param_grandpa.data = (self.param_grandpa.data +
                                   1).data  # test param is not yet partitioned


class Son(Pa):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(5))
        self.param.data = (self.param.data + 1).data  # test param is not yet partitioned
        self.param_pa.data = (self.param_pa.data +
                              1).data  # test param is not yet partitioned
        self.param_grandpa.data = (self.param_grandpa.data +
                                   1).data  # test param is not yet partitioned


class TestSerialParamInit(DistributedTest):
    world_size = 1
    init_distributed = False
    set_dist_env = False

    def test_subclass_param_init(self):
        setup_serial_env()
        with deepspeed.zero.Init(config=config):
            model = Son().cpu()

        # test that all params have been partitioned
        assert model.param_grandpa.ds_status == ZeroParamStatus.NOT_AVAILABLE
        assert model.param_pa.ds_status == ZeroParamStatus.NOT_AVAILABLE
        assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE

        # test that the weight manipulation performed in each __init__ took effect
        # without needing to gather the params first
        ones = torch.ones(5).half().to(get_accelerator().device_name())
        with deepspeed.zero.GatheredParameters(list(model.parameters(recurse=False))):
            assert torch.equal(model.param, ones + 1)
            assert torch.equal(model.param_pa, ones + 2)
            assert torch.equal(model.param_grandpa, ones + 3)


class TestDSInitWZinit(DistributedTest):
    world_size = 2

    def test(self):
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def magic(self):
                return 42
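
        # build the model under zero.Init and then wrap it with deepspeed.initialize; methods
        # defined on the plain module (magic) should stay reachable through the returned engine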

        with deepspeed.zero.Init():
            model = Model()
            engine, *_ = deepspeed.initialize(model=model, config=ds_config, model_parameters=model.parameters())
        assert engine.magic() == 42