# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in pipeline parallel"""

import unittest

import numpy as np

import paddle
from paddle.distributed import fleet

from paddle.distributed.fleet.meta_parallel import (
    LayerDesc,
    PipelineLayer,
)

from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te


class TELinear(te.Linear):
    """To pass is_first_microbatch"""

    def __init__(self, *args, **kwargs):
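        # `accumulate_steps` is consumed by this wrapper; te.Linear itself does
        # not accept it, so it is removed before calling super().__init__().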
        assert 'accumulate_steps' in kwargs
        self.accumulate_steps = kwargs['accumulate_steps']
        del kwargs['accumulate_steps']
        self._micro_batch_id = 0
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
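        # The hint is True only for the first microbatch of each accumulation
        # window; the counter advances only when training with gradients enabled,
        # so evaluation forwards do not shift the window.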
        kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0
        if paddle.is_grad_enabled() and self.training:
            self._micro_batch_id += 1
        return super().forward(*args, **kwargs)


class TEPipelineModel(PipelineLayer):
    """Model for pipeline parallel test"""

    def __init__(self,
                 in_features,
                 hidden_features,
                 weight_attrs,
                 use_te=True,
                 use_fp8=False,
                 accumulate_steps=1,
                 **kwargs):
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.fp8 = use_fp8
        hcg = fleet.get_hybrid_communicate_group()
        self.dp_group = hcg.get_data_parallel_group()
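        # FP8 amax/scale reductions are performed over the data-parallel group;
        # the pipeline stages hold different layers, so no reduction across pp
        # ranks is needed.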

        Linear = TELinear if use_te else paddle.nn.Linear
        extra_kwargs = {}
        if use_te:
            extra_kwargs['accumulate_steps'] = accumulate_steps

        model_desc = [
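            # LayerDesc defers construction so PipelineLayer instantiates only
            # the layers assigned to the local pipeline stage.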
            LayerDesc(Linear,
                      self.in_features,
                      self.hidden_features,
                      weight_attr=weight_attrs[0],
                      **extra_kwargs),
            LayerDesc(Linear,
                      self.hidden_features,
                      self.in_features,
                      weight_attr=weight_attrs[1],
                      **extra_kwargs),
        ]
        super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)

    def forward(self, *args, **kwargs):
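        # fp8_autocast must wrap every forward pass; the pipeline engine calls
        # forward() once per microbatch, so the context is re-entered each time.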
        with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group):
            return super().forward(*args, **kwargs)


class StandaloneModel(paddle.nn.Layer):
    """Model for pipeline parallel test"""

    def __init__(self, in_features, hidden_features, weight_attrs):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        Linear = paddle.nn.Linear
        self.layer = paddle.nn.Sequential(
            Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]),
            Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]),
        )
        self.loss = paddle.nn.CrossEntropyLoss()

    def forward(self, inp):
        out = self.layer(inp[0])
        loss = self.loss(out, inp[1])
        return loss


class TestLinearPipelineParallel(unittest.TestCase):
    """Tests Linear layer with pipeline parallel"""

    def setUp(self):
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": self.pipeline_parallel_size,
        }
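        # batch_size=32 / micro_batch_size=16 -> 2 microbatches per training step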
        self.accumulate_steps = self.batch_size // self.micro_batch_size
        strategy.pipeline_configs = {
            "accumulate_steps": self.accumulate_steps,
            "micro_batch_size": self.micro_batch_size,
        }
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 32
        self.micro_batch_size = 16
        self.in_features = 32
        self.hidden_features = 64
        self.global_dtype = 'float32'
        self.rtol = 1e-5
        self.atol = 1e-5
        self.iter = 10
        self.fp8 = False

    def test_pipeline_train(self):
        """Test pipeline parallel training"""
        set_random_seed(1024)
        np.random.seed(1024)

        weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
        weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
        weight_attrs = [
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)),
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)),
        ]
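        # te.Linear stores its weight as [out_features, in_features], transposed
        # relative to paddle.nn.Linear, so the TE pipeline model receives
        # transposed initial values.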
        weight_attrs_transposed = [
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)),
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)),
        ]

        pipe_model = TEPipelineModel(
            self.in_features,
            self.hidden_features,
            weight_attrs_transposed,
            use_te=True,
            use_fp8=self.fp8,
            seg_method="layer:Linear",
            num_stages=self.pipeline_parallel_size,
            accumulate_steps=self.accumulate_steps,
        )

        # Check if model is split across ranks as expected
        for name, sublayer in pipe_model.named_sublayers():
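            # PipelineLayer registers its own sublayers (loss function, shared
            # layers); only the Linear stages carry the weights checked here.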
            if name in ('_loss_fn', 'shared_layers'):
                continue
            if self.rank == 0:
                assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \
                    f"Shape mismatch, expected: {weight1_np.T.shape}, " \
                    f"actual: {tuple(sublayer.weight.shape)}"
            elif self.rank == 1:
                assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \
                    f"Shape mismatch, expected: {weight2_np.T.shape}, " \
                    f"actual: {tuple(sublayer.weight.shape)}"

        standalone_model = StandaloneModel(
            self.in_features,
            self.hidden_features,
            weight_attrs,
        )

        optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
        optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1,
                                            parameters=standalone_model.parameters())

        pipe_model = fleet.distributed_model(pipe_model)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        def train_one_step(layer, inp, optimizer):
            loss = layer(inp)
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            return loss

        for i in range(self.iter):
            inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]),
                                   dtype=self.global_dtype)
            label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
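            # train_batch() splits [inp, label] into microbatches and runs the
            # pipeline schedule; the standalone model consumes the full batch in
            # one step, so the averaged losses should match within tolerance.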
            loss_te = pipe_model.train_batch([inp, label], optimizer_te)
            loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
            print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}")
            assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)


class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
    """Tests Linear layer with column/row parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 32
        self.micro_batch_size = 16
        self.in_features = 32
        self.hidden_features = 64
        self.global_dtype = 'float32'
        self.rtol = 5e-2
        self.atol = 5e-2
        self.iter = 10
        self.fp8 = True


if __name__ == '__main__':
    unittest.main()