"vscode:/vscode.git/clone" did not exist on "f514ab64baae4862488e068c38829286ae9296fc"
test_zero_context.py 3.92 KB
Newer Older
Samyam Rajbhandari's avatar
Samyam Rajbhandari committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
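"""Tests for the DeepSpeed ZeRO parameter-partitioning context managers
(deepspeed.zero.Init and deepspeed.zero.GatheredParameters)."""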
import os
import torch
import pytest

import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from common import distributed_test


def setup_serial_env():
    # Setup for a serial run
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29503'
    os.environ['LOCAL_RANK'] = '0'
    os.environ['RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'


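# zero.Init() should bring up torch.distributed on its own when it has not
# been initialized yet (single-process, serial case).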
def test_scattered_init_dist():
    setup_serial_env()
    assert not torch.distributed.is_initialized()
    with deepspeed.zero.Init():
        assert torch.distributed.is_initialized()


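# Parameters built under zero.Init() stay partitioned (a 1-element placeholder)
# until explicitly gathered; modules built outside the context are untouched.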
@distributed_test(world_size=2)
def test_scatter_gather():
    with deepspeed.zero.Init():
        l = torch.nn.Linear(6, 3)
    assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE
    assert l.weight.numel() == 1

    # Ensure there is no impact outside the context
    l2 = torch.nn.Linear(6, 3)
    assert not hasattr(l2.weight, 'ds_status')
    assert l2.weight.numel() == l2.in_features * l2.out_features

    with deepspeed.zero.GatheredParameters(l.weight):
        assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
        assert l.weight.numel() == l.in_features * l.out_features


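# A change made by modifier_rank inside GatheredParameters should be visible
# on all ranks when the parameter is gathered again.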
@distributed_test(world_size=2)
def test_gather_update():
    with deepspeed.zero.Init():
        l = torch.nn.Linear(4, 2)
    assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE

    # Gather and make a change
    with deepspeed.zero.GatheredParameters(l.weight, modifier_rank=1):
        assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
        if torch.distributed.get_rank() == 1:
            with torch.no_grad():
                l.weight.zero_()

    # should now be scattered again

    # Now gather again and ensure the change is global
    with deepspeed.zero.GatheredParameters(l.weight):
        # all ranks compare
        assert torch.equal(l.weight, torch.zeros_like(l.weight))


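# WIP: exercises registering a submodule's weight as an external parameter so
# the parent module can use it directly in forward().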
@pytest.mark.skip('WIP')
def test_external_param():
    setup_serial_env()

    print()

    class ExtLinear(torch.nn.Module):
        def __init__(self, dim=10, copycat=None):
            super().__init__()
            self.dim = dim
            self.linear = torch.nn.Linear(dim, dim)
            if copycat is not None:
                with deepspeed.zero.GatheredParameters(self.linear.weight,
                                                  modifier_rank=0), \
                     torch.no_grad():
                    self.linear.weight.copy_(copycat.linear.weight)

            if hasattr(self.linear.weight, 'ds_id'):
                print('registering')
                # nn.Module has no ds_register_external_parameter method; use
                # the public DeepSpeed API to register the shared weight.
                deepspeed.zero.register_external_parameter(self, self.linear.weight)

        def forward(self, input):
            yamsam = self.linear(input)
            if hasattr(self.linear.weight, 'ds_status'):
                assert self.linear.weight.ds_status == ZeroParamStatus.AVAILABLE
            jeff = torch.nn.functional.linear(yamsam, self.linear.weight)
            return jeff

    l1_base = ExtLinear().half().cuda()
    l2_base = ExtLinear().half().cuda()

    input = torch.rand(10).half().cuda()

    l1_base_out = l1_base(input.clone().detach())
    l2_base_out = l2_base(input.clone().detach())

    with deepspeed.zero.Init():
        l1_test = ExtLinear(copycat=l1_base).cuda()
        #l2_test = ExtLinear(copycat=l2_base).cuda()
        assert l1_test.linear.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE

    # XXX l1 and l2 share their external parameter (l2.linear.weight)

    assert l1_test.linear.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE
    l1_test_out = l1_test(input.clone().detach())
    #assert torch.allclose(l1_base_out, l1_test_out)

    #l2_test_out = l2_test(input.clone().detach())
    #assert torch.allclose(l2_base_out, l2_test_out)


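# zero.Init() should partition parameters as fp16 while leaving explicitly
# created integer tensors untouched.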
def test_scatter_halftype():
    setup_serial_env()

    with deepspeed.zero.Init():
        l = torch.nn.Linear(10, 10)
        assert l.weight.ds_tensor.dtype == torch.float16

        y = torch.LongTensor([3, 3])
        assert y.dtype == torch.long