test_tie_weights.py 7.38 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import torch
from helpers.distributed_tensor import assert_tensor_equal_over_group
from helpers.exception import assert_fail_with
from helpers.utils import init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.tied_parameters import (
    get_tied_id_to_param,
    sync_tied_weights_gradients,
    tie_parameters,
)
from torch import nn


@rerun_if_address_is_in_use()
def test_tie_weight_in_same_device():
    """Run the same-device tying check under a tp=1, dp=1, pp=1 layout."""
    launch = init_distributed(tp=1, dp=1, pp=1)
    launch(_test_tie_weight_in_same_device)()


def _test_tie_weight_in_same_device(parallel_context: ParallelContext):
    """Tie the weights and biases of two linear layers held by one module.

    After tying, ``dense0`` and ``dense1`` are expected to expose the very same
    parameter object for both ``weight`` and ``bias``.
    """
    layers = {name: nn.Linear(10, 10, device="cuda") for name in ("dense0", "dense1")}
    model = nn.ModuleDict(layers)

    tied_pairs = (
        ("dense0.weight", "dense1.weight"),
        ("dense0.bias", "dense1.bias"),
    )

    # Tie the weights first, then the biases.
    # NOTE(review): the `(0,)` tuples presumably name the rank(s) holding each parameter - confirm.
    for first_name, second_name in tied_pairs:
        tie_parameters(
            root_module=model,
            ties=[(first_name, (0,)), (second_name, (0,))],
            parallel_context=parallel_context,
            reduce_op=dist.ReduceOp.SUM,
        )

    # Both linear layers must now share a single parameter object per tied pair
    for first_name, second_name in tied_pairs:
        assert model.get_parameter(first_name) is model.get_parameter(second_name)

    parallel_context.destroy()


@rerun_if_address_is_in_use()
def test_tie_weight_in_different_device():
    """Run the cross-device tying check under a tp=1, dp=1, pp=2 layout."""
    launch = init_distributed(tp=1, dp=1, pp=2)
    launch(_test_tie_weight_in_different_device)()


def _test_tie_weight_in_different_device(parallel_context: ParallelContext):
    """Tie parameters of two linear layers that live on different pipeline ranks.

    Pipeline rank 0 builds ``dense0`` and the other rank builds ``dense1``. The test checks
    that tying marks the parameters as tied ``NanotronParameter`` instances without making
    their values equal, and that the values match across the group once every tied
    parameter has been all-reduced by hand.
    """
    if dist.get_rank(parallel_context.pp_pg) == 0:
        model = nn.ModuleDict(
            {
                "dense0": nn.Linear(10, 10, device="cuda"),
            }
        )
    else:
        model = nn.ModuleDict(
            {
                "dense1": nn.Linear(10, 10, device="cuda"),
            }
        )

    # Tie weights/bias
    tie_parameters(
        root_module=model,
        ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))],
        parallel_context=parallel_context,
        reduce_op=dist.ReduceOp.SUM,
    )
    tie_parameters(
        root_module=model,
        ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))],
        parallel_context=parallel_context,
        reduce_op=dist.ReduceOp.SUM,
    )

    # Process group used by every equality check in this test
    group = parallel_context.world_ranks_to_pg[(0, 1)]

    # Pick the parameters of the layer that was built on this pipeline rank
    if dist.get_rank(parallel_context.pp_pg) == 0:
        weight = model.dense0.weight
        bias = model.dense0.bias
    else:
        weight = model.dense1.weight
        bias = model.dense1.bias

    # Make sure that weight/bias are NanotronParameter and that they are tied
    assert isinstance(weight, NanotronParameter)
    assert weight.is_tied
    assert isinstance(bias, NanotronParameter)
    assert bias.is_tied

    # Weights/bias are not synced yet
    assert not assert_tensor_equal_over_group(weight, group=group, assert_=False)
    assert not assert_tensor_equal_over_group(bias, group=group, assert_=False)

    # Manually sync weights. Sorting by tied id gives the same iteration order wherever the
    # mapping has the same keys. The loop keeps its own name for the process group so that
    # `group`, used by the final assertions, is never rebound.
    for (_, group_ranks), param in sorted(
        get_tied_id_to_param(
            parameters=model.parameters(),
            root_module=model,
        ).items(),
        key=lambda x: x[0],
    ):
        tied_group = parallel_context.world_ranks_to_pg[group_ranks]
        dist.all_reduce(param, op=dist.ReduceOp.AVG, group=tied_group)

    # The tied weights/bias must now hold identical values across the group
    assert_tensor_equal_over_group(weight, group=group)
    assert_tensor_equal_over_group(bias, group=group)

    parallel_context.destroy()


@rerun_if_address_is_in_use()
def test_tie_weight_across_dp_is_impossible():
    """Run the data-parallel tying rejection check under a tp=1, dp=2, pp=1 layout."""
    launch = init_distributed(tp=1, dp=2, pp=1)
    launch(_test_tie_weight_across_dp_is_impossible)()


def _test_tie_weight_across_dp_is_impossible(parallel_context: ParallelContext):
    """Check that tying parameters held by different data-parallel ranks is rejected.

    Data-parallel rank 0 builds ``dense0`` and the other rank builds ``dense1``; each
    ``tie_parameters`` call is wrapped in ``assert_fail_with(AssertionError)``.
    """
    layer_name = "dense0" if dist.get_rank(parallel_context.dp_pg) == 0 else "dense1"
    model = nn.ModuleDict({layer_name: nn.Linear(10, 10, device="cuda")})

    # Weight tie first, then bias tie; both attempts are expected to be refused
    rejected_ties = (
        [("dense0.weight", (0,)), ("dense1.weight", (1,))],
        [("dense0.bias", (0,)), ("dense1.bias", (1,))],
    )
    for ties in rejected_ties:
        with assert_fail_with(AssertionError):
            tie_parameters(
                root_module=model,
                ties=ties,
                parallel_context=parallel_context,
                reduce_op=dist.ReduceOp.SUM,
            )

    parallel_context.destroy()


@rerun_if_address_is_in_use()
def test_tie_weight_in_different_device_have_gradients_synchronized():
    """Run the tied-gradient synchronization check under a tp=1, dp=1, pp=2 layout."""
    launch = init_distributed(tp=1, dp=1, pp=2)
    launch(_test_tie_weight_in_different_device_have_gradients_synchronized)()


def _test_tie_weight_in_different_device_have_gradients_synchronized(parallel_context: ParallelContext):
    """Check that gradients of parameters tied across pipeline ranks end up identical.

    Pipeline rank 0 builds ``dense0`` and the other rank builds ``dense1``. Each rank runs a
    backward pass on its own random input, then ``sync_tied_weights_gradients`` is called and
    the resulting gradients are compared across the group.
    """
    pp_rank = dist.get_rank(parallel_context.pp_pg)
    layer_name = "dense0" if pp_rank == 0 else "dense1"
    model = nn.ModuleDict({layer_name: nn.Linear(10, 10, device="cuda")})

    # Tie weights first, then biases
    tie_specs = (
        [("dense0.weight", (0,)), ("dense1.weight", (1,))],
        [("dense0.bias", (0,)), ("dense1.bias", (1,))],
    )
    for ties in tie_specs:
        tie_parameters(
            root_module=model,
            ties=ties,
            parallel_context=parallel_context,
            reduce_op=dist.ReduceOp.SUM,
        )

    group = parallel_context.world_ranks_to_pg[(0, 1)]

    # Grab the parameters of the layer that was built on this pipeline rank
    layer = model[layer_name]
    weight, bias = layer.weight, layer.bias

    # Both must be tied NanotronParameter instances
    for tensor in (weight, bias):
        assert isinstance(tensor, NanotronParameter)
        assert tensor.is_tied

    # Tying alone leaves the values different across the group
    for tensor in (weight, bias):
        assert not assert_tensor_equal_over_group(tensor, group=group, assert_=False)

    # Compute gradient
    input_ = torch.randn(13, 10, device="cuda")
    layer(input_).sum().backward()

    # sync gradients
    # TODO @thomasw21: This should be done in hooks
    sync_tied_weights_gradients(model, parallel_context=parallel_context, grad_accumulator=None)

    # Check that we have gradient
    assert weight.grad is not None
    assert bias.grad is not None

    # Both gradients must now be identical across the group
    assert_tensor_equal_over_group(weight.grad, group=group)
    assert_tensor_equal_over_group(bias.grad, group=group)

    parallel_context.destroy()