# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Transformer layer in tensor parallel"""

import unittest

import paddle
from paddle.distributed import fleet
10
from paddle.distributed.fleet.layers.mpu import mp_ops
11

12
from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import transformer_engine.paddle as te


class TestTransformerTp(unittest.TestCase):
    """Tests Transformer layer with model parallel in BF16"""

    def setUp(self):
        """Load the test config, bring up Fleet, and set the default dtype."""
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
        }
        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()
        self.tp_group = self.hcg.get_model_parallel_group()
        self.world_size = self.hcg.get_model_parallel_world_size()

    def set_attr(self):
        """Set test configs"""
        # Problem sizes.
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.ffn_hidden_size = 4096
        self.q_seqlen = 128
        self.kv_seqlen = 128
        # Layer configuration.
        self.mask_type = "padding"
        self.layer_type = "encoder"
        self.global_dtype = "bfloat16"
        # Comparison tolerances and layernorm epsilon.
        self.rtol = 5e-2
        self.atol = 5e-2
        self.eps = 1e-3
        # Feature switches overridden by the subclasses.
        self.fp8 = False
        self.sequence_parallel = False

    def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
        """Run one forward/backward/optimizer step and return ``(loss, total_out)``.

        Args:
            layer: Callable layer invoked as ``layer(input, mask)``.
            inp_list: Two-element sequence ``(inp, mask)``.
            optimizer: Optimizer whose ``step`` and ``clear_grad`` are called.
            fp8_enabled: Forwarded to ``te.fp8_autocast(enabled=...)``.
            sequence_parallel: When true, only this rank's slice of ``inp``
                along axis 0 is fed to the layer, and the per-rank outputs are
                gathered before the loss is computed.

        Returns:
            Tuple of the mean-reduced loss and the output it was computed from.
        """
        inp, mask = inp_list
        if sequence_parallel:
            # Each rank consumes a contiguous chunk of axis 0.
            split_size = inp.shape[0] // self.world_size
            input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
        else:
            input_parallel = inp
        with te.fp8_autocast(enabled=fp8_enabled):
            out = layer(input_parallel, mask)
        if sequence_parallel:
            # Gather the per-rank outputs across the TP group, then split the
            # result along the last axis and re-stack the pieces along axis 0.
            total_out = mp_ops._c_concat(out, group=self.tp_group)
            total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
        else:
            total_out = out
        loss = total_out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        return loss, total_out

    def test_parallel_layer(self):
        """Tests parallel Transformer"""
        set_random_seed(1024)
        common_args = [
            self.hidden_size,
            self.ffn_hidden_size,
            self.num_heads,
        ]
        common_kwargs = {
            "layernorm_epsilon": self.eps,
            "hidden_dropout": 0.0,
            "attention_dropout": 0.0,
            "self_attn_mask_type": self.mask_type,
            "layer_type": self.layer_type,
        }
        layer_tp = te.TransformerLayer(
            *common_args,
            **common_kwargs,
            set_parallel_mode=True,
            sequence_parallel=self.sequence_parallel,
        )
        layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)

        def _get_total_weight(local_weight, tp_group, axis, interleave=False):
            """All-gather ``local_weight`` over ``tp_group`` and join the shards.

            With ``interleave=False`` the shards are concatenated along
            ``axis``. With ``interleave=True`` (only valid for ``axis == 0``)
            each shard is viewed as ``[3, local_num_head, -1, hidden_size]``
            and the shards are joined along the head dimension.
            """
            total_weight = []
            partial_weight = local_weight.clone().detach()
            paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)

            if interleave:
                # Due to the interleaved qkv layout, need to concat on num_head
                # dimension for column parallel linear in MultiHeadAttention layer
                assert axis == 0
                assert [
                    3 * self.hidden_size // self.world_size,
                    self.hidden_size,
                ] == partial_weight.shape
                local_num_head = self.num_heads // self.world_size
                total_weight = [
                    shard.reshape([3, local_num_head, -1, self.hidden_size])
                    for shard in total_weight
                ]
                total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
            else:
                total_weight = paddle.concat(total_weight, axis=axis)
            return total_weight

        def _get_weight(obj, weight_names):
            """Follow the attribute path ``weight_names`` starting at ``obj``."""
            for name in weight_names:
                obj = getattr(obj, name)
            return obj

        def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
            """Copy one weight from ``layer_src`` into ``layer_dst``.

            ``partition_mode`` selects how the source weight is assembled:
            ``None`` uses it as is, ``"column"`` gathers shards along axis 0,
            and ``"row"`` gathers shards along axis 1. Any other value raises
            ``ValueError``.
            """
            weight_src = _get_weight(layer_src, weight_names)
            weight_dst = _get_weight(layer_dst, weight_names)
            if partition_mode is None:
                total_weight = weight_src
            elif partition_mode == "column":
                total_weight = _get_total_weight(
                    weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
                )
            elif partition_mode == "row":
                total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
            else:
                raise ValueError(f"Partition Mode {partition_mode} is not supported.")
            assert (
                weight_dst.shape == total_weight.shape
            ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
            weight_dst.copy_(total_weight, True)

        # Give the single-device layer the same weights as the parallel layer.
        copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"])
        copy_weight(
            layer_tp,
            layer_single,
            "column",
            ["self_attention", "layernorm_qkv", "weight"],
            interleave=True,
        )
        copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"])
        copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"])
        copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"])
        copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"])

        if self.sequence_parallel:
            register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)

        optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
        optimizer_single = paddle.optimizer.SGD(
            learning_rate=0.01, parameters=layer_single.parameters()
        )

        layer_tp = fleet.distributed_model(layer_tp)
        optimizer_tp = fleet.distributed_optimizer(optimizer_tp)

        # Train both layers on identical inputs and compare after every step.
        for _ in range(5):
            inp = paddle.uniform(
                [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
            )
            mask = paddle.zeros(
                shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
            )
            loss_tp, out_tp = self._train_one_step(
                layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
            )
            loss_single, out_single = self._train_one_step(
                layer_single, [inp, mask], optimizer_single, self.fp8
            )
            assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
            assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)


class TestTransformerTpFp8(TestTransformerTp):
    """Tests Transformer layer with tensor parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        # Start from the parent configuration and change only what differs:
        # FP8 execution is enabled and the absolute tolerance is loosened.
        super().set_attr()
        self.atol = 0.5
        self.fp8 = True


class TestTransformerSp(TestTransformerTp):
    """Tests Transformer layer with sequence parallel in BF16"""

    def set_attr(self):
        """Set test configs"""
        # Start from the parent configuration and change only what differs:
        # sequence parallelism is switched on.
        super().set_attr()
        self.sequence_parallel = True


class TestTransformerSpFp8(TestTransformerSp):
    """Tests Transformer layer with sequence parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        # Start from the parent configuration and change only what differs:
        # FP8 execution is enabled and the absolute tolerance is loosened.
        super().set_attr()
        self.atol = 0.5
        self.fp8 = True
247
248


249
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()