# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in tensor parallel"""

import unittest

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te


class TestLinearTp(unittest.TestCase):
    """Tests Linear layer with column/row parallelism in BF16.

    Each test builds a tensor-parallel TE Linear layer and a single-rank
    Paddle-backend reference layer initialized with the gathered (full)
    weight, trains both for a few steps, and checks that loss and input
    gradients agree within tolerance.
    """

    def setUp(self):
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
        }
        # Inputs are created identically on every rank (same RNG seed), so the
        # implicit broadcast of input data across the MP group is unnecessary.
        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()
        self.tp_group = self.hcg.get_model_parallel_group()
        self.world_size = self.hcg.get_model_parallel_world_size()

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = "bfloat16"
        self.rtol = 1e-3
        self.atol = 1e-3
        self.fp8 = False
        self.sequence_parallel = False

    def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
        """Run one fwd/bwd/optimizer step and return (loss, full input grad).

        Args:
            layer: the layer under test (TE parallel layer or Paddle reference).
            inp: full (unsplit) input array/tensor, identical on all ranks.
            optimizer: optimizer stepping the layer's parameters.
            split_input: "none" feeds the full input; "column"/"row" feed this
                rank's slice along axis 1 / axis 0 respectively (used for
                row-parallel and sequence-parallel layers).
            gather_output: if True, all-gather the layer output across the TP
                group before reducing to a loss, so the loss matches the
                single-rank reference.

        Returns:
            (loss, grad_input) where grad_input is the gradient w.r.t. the
            FULL input (all-gathered and re-concatenated when the input was
            split).
        """
        inp = paddle.to_tensor(inp, stop_gradient=True)
        assert split_input in ["none", "column", "row"]
        if split_input == "column":
            # Feature-dimension slice for this rank.
            split_size = inp.shape[1] // self.world_size
            input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
        elif split_input == "row":
            # Batch/sequence-dimension slice for this rank.
            split_size = inp.shape[0] // self.world_size
            input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
        else:
            input_parallel = inp
        # Enable gradient tracking on the (possibly sliced) input.
        input_parallel.stop_gradient = False
        out = layer(input_parallel)
        if gather_output:
            # Concatenate the per-rank partial outputs so the loss is computed
            # over the full output, matching the non-parallel reference.
            total_out = mp_ops._c_concat(out, group=self.tp_group)
        else:
            total_out = out
        loss = total_out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        if split_input != "none":
            # Reassemble the full input gradient from each rank's slice,
            # concatenating along the same axis the input was split on.
            grad_input = []
            paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
            if split_input == "column":
                grad_input = paddle.concat(grad_input, axis=1)
            elif split_input == "row":
                grad_input = paddle.concat(grad_input, axis=0)
        else:
            grad_input = input_parallel.grad
        return loss, grad_input

    def test_column_parallel_layer(self):
        """Tests column parallel linear"""
        set_random_seed(1024)
        layer_te = te.Linear(
            self.in_features,
            self.out_features,
            parallel_mode="column",
            sequence_parallel=self.sequence_parallel,
        )
        layer_pd = te.Linear(
            self.in_features,
            self.out_features,
            backend="paddle",
        )
        # Get total weight: gather each rank's output-dim shard and stack them
        # so the reference layer starts from the same effective weight.
        total_weight = []
        partial_weight = layer_te.weight.clone().detach()
        paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
        total_weight = paddle.concat(total_weight, axis=0)
        # TE stores weight as [out, in]; the Paddle backend expects [in, out].
        layer_pd.weight.copy_(total_weight.T, True)

        # Column parallelism shards the output dimension across ranks.
        assert_shape(
            layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
        )
        assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])

        optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
        optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())

        layer_te = fleet.distributed_model(layer_te)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
            with te.fp8_autocast(enabled=self.fp8):
                # With sequence parallelism the input is scattered along the
                # batch/sequence axis; output is gathered to compare losses.
                loss_tp, grad_input = self._train_one_step(
                    layer_te,
                    inp,
                    optimizer_te,
                    split_input="row" if self.sequence_parallel else "none",
                    gather_output=True,
                )
            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
            assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
            assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)

    def test_row_parallel_layer(self):
        """Tests row parallel linear"""
        set_random_seed(1024)
        layer_te = te.Linear(
            self.in_features,
            self.out_features,
            parallel_mode="row",
            sequence_parallel=self.sequence_parallel,
        )
        layer_pd = te.Linear(
            self.in_features,
            self.out_features,
            backend="paddle",
        )
        # Get total weight: row parallelism shards the input dimension, so the
        # gathered shards are concatenated along axis 1.
        total_weight = []
        partial_weight = layer_te.weight.clone().detach()
        paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
        total_weight = paddle.concat(total_weight, axis=1)
        # TE stores weight as [out, in]; the Paddle backend expects [in, out].
        layer_pd.weight.copy_(total_weight.T, True)

        # Row parallelism shards the input dimension; bias stays replicated.
        assert_shape(
            layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size]
        )
        assert_shape(layer_te.bias, [self.out_features])

        optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
        optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())

        layer_te = fleet.distributed_model(layer_te)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
            with te.fp8_autocast(enabled=self.fp8):
                # Row-parallel layers consume a feature-dim slice of the input;
                # with sequence parallelism the output is scattered, so gather
                # it back for comparison against the reference.
                loss_tp, grad_input = self._train_one_step(
                    layer_te,
                    inp,
                    optimizer_te,
                    split_input="column",
                    gather_output=self.sequence_parallel,
                )
            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
            assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
            assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)


class TestLinearTpFP8(TestLinearTp):
    """Tests Linear layer with column/row parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        # Same problem sizes as the BF16 case; only precision knobs change.
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = "bfloat16"
        self.fp8 = True
        self.sequence_parallel = False
        # FP8 accumulates more rounding error, hence looser tolerances.
        self.rtol, self.atol = 1e-2, 1e-2


class TestLinearSp(TestLinearTp):
    """Tests Linear layer with sequence parallelism"""

    def set_attr(self):
        """Set test configs"""
        # Identical to the base TP config except sequence parallelism is on.
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = "bfloat16"
        self.fp8 = False
        self.sequence_parallel = True
        self.rtol, self.atol = 1e-3, 1e-3


class TestLinearSpFP8(TestLinearTp):
    """Tests Linear layer with sequence parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        # Sequence parallelism combined with FP8 execution.
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = "bfloat16"
        self.fp8 = True
        self.sequence_parallel = True
        # FP8 accumulates more rounding error, hence looser tolerances.
        self.rtol, self.atol = 1e-2, 1e-2
219
220


# Allow running this test module directly (normally launched per-rank via
# paddle.distributed.launch by the test driver).
if __name__ == "__main__":
    unittest.main()