# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in tensor parallel"""

import unittest

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te


class TestLinearTp(unittest.TestCase):
    """Tests Linear layer with column/row parallelism in BF16"""

    def setUp(self):
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
        }
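        # This test shards inputs manually, so disable Fleet's automatic
        # broadcast of input data across the model-parallel group.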
        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()
        self.tp_group = self.hcg.get_model_parallel_group()
        self.world_size = self.hcg.get_model_parallel_world_size()

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = 'bfloat16'
        self.rtol = 1e-3
        self.atol = 1e-3
        self.fp8 = False
        self.sequence_parallel = False

    def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
        inp = paddle.to_tensor(inp, stop_gradient=True)
        assert split_input in ['none', 'column', 'row']
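        # Shard the input for this rank: 'column' slices the feature dim
        # (row-parallel layer), 'row' slices the batch dim (sequence parallel).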
        if split_input == 'column':
            split_size = inp.shape[1] // self.world_size
            input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
        elif split_input == 'row':
            split_size = inp.shape[0] // self.world_size
            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
        else:
            input_parallel = inp
        input_parallel.stop_gradient = False
        out = layer(input_parallel)
        if gather_output:
            total_out = mp_ops._c_concat(out, group=self.tp_group)
        else:
            total_out = out
        loss = total_out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
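        # Gather the local input gradients so they can be compared against the
        # unsharded single-process reference.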
        if split_input != 'none':
            grad_input = []
            paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
            if split_input == 'column':
                grad_input = paddle.concat(grad_input, axis=1)
            elif split_input == 'row':
                grad_input = paddle.concat(grad_input, axis=0)
        else:
            grad_input = input_parallel.grad
        return loss, grad_input

    def test_column_parallel_layer(self):
        """Tests column parallel linear"""
        set_random_seed(1024)
        layer_te = te.Linear(
            self.in_features,
            self.out_features,
            parallel_mode='column',
            sequence_parallel=self.sequence_parallel,
        )
        layer_pd = te.Linear(
            self.in_features,
            self.out_features,
            backend='paddle',
        )
        # Get total weight
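        # Column parallelism shards the output features, so the weight shards
        # concatenate along axis 0; the Paddle backend stores the weight as
        # [in_features, out_features], hence the transpose.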
        total_weight = []
        partial_weight = layer_te.weight.clone().detach()
        paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
        total_weight = paddle.concat(total_weight, axis=0)
        layer_pd.weight.copy_(total_weight.T, True)

        assert_shape(layer_te.weight,
                     [self.out_features // self.model_parallel_size, self.in_features])
        assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])

        optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
        optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())

        layer_te = fleet.distributed_model(layer_te)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
            with te.fp8_autocast(enabled=self.fp8):
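                # The column-parallel layer produces a sharded output, so gather
                # it before taking the loss to match the reference layer.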
                loss_tp, grad_input = self._train_one_step(
                    layer_te,
                    inp,
                    optimizer_te,
                    split_input='row' if self.sequence_parallel else 'none',
                    gather_output=True)
            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
            assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
            assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)

    def test_row_parallel_layer(self):
        """Tests row parallel linear"""
        set_random_seed(1024)
        layer_te = te.Linear(
            self.in_features,
            self.out_features,
            parallel_mode='row',
            sequence_parallel=self.sequence_parallel,
        )
        layer_pd = te.Linear(
            self.in_features,
            self.out_features,
            backend='paddle',
        )
        # Get total weight
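        # Row parallelism shards the input features, so the weight shards
        # concatenate along axis 1 before being copied to the reference layer.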
        total_weight = []
        partial_weight = layer_te.weight.clone().detach()
        paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
        total_weight = paddle.concat(total_weight, axis=1)
        layer_pd.weight.copy_(total_weight.T, True)

        assert_shape(layer_te.weight,
                     [self.out_features, self.in_features // self.model_parallel_size])
        assert_shape(layer_te.bias, [self.out_features])

        optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
        optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())

        layer_te = fleet.distributed_model(layer_te)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
            with te.fp8_autocast(enabled=self.fp8):
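                # The input features are sharded across ranks; with sequence
                # parallelism the output is also scattered, so gather it back.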
                loss_tp, grad_input = self._train_one_step(layer_te,
                                                           inp,
                                                           optimizer_te,
                                                           split_input='column',
                                                           gather_output=self.sequence_parallel)
            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
            assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
            assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)


class TestLinearTpFP8(TestLinearTp):
    """Tests Linear layer with column/row parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = 'bfloat16'
        self.rtol = 1e-2
        self.atol = 1e-2
        self.fp8 = True
        self.sequence_parallel = False


class TestLinearSp(TestLinearTp):
    """Tests Linear layer with sequence parallelism"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = 'bfloat16'
        self.rtol = 1e-3
        self.atol = 1e-3
        self.fp8 = False
        self.sequence_parallel = True


class TestLinearSpFP8(TestLinearTp):
    """Tests Linear layer with sequence parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = 'bfloat16'
        self.rtol = 1e-2
        self.atol = 1e-2
        self.fp8 = True
        self.sequence_parallel = True


if __name__ == '__main__':
    unittest.main()