# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in pipeline parallel"""

import unittest

import numpy as np

import paddle
from paddle.distributed import fleet

from paddle.distributed.fleet.meta_parallel import (
    LayerDesc,
    PipelineLayer,
)

from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te


class TELinear(te.Linear):
    """te.Linear wrapper that passes ``is_first_microbatch`` to every forward call.

    The pipeline runner calls each layer once per micro-batch and has no way to
    hand extra keyword arguments to the layer, so this wrapper counts its own
    training forward calls and derives the flag from that counter.

    Args:
        *args: Positional arguments forwarded to ``te.Linear``.
        **kwargs: Keyword arguments forwarded to ``te.Linear``. Must contain
            ``accumulate_steps`` (number of micro-batches per global batch),
            which is consumed here and not forwarded.
    """

    def __init__(self, *args, **kwargs):
        assert "accumulate_steps" in kwargs, "TELinear requires `accumulate_steps`"
        # Remove the key before delegating: te.Linear is not given this argument.
        self.accumulate_steps = kwargs.pop("accumulate_steps")
        self._micro_batch_id = 0
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run ``te.Linear.forward`` with ``is_first_microbatch`` filled in.

        The flag is True on every ``accumulate_steps``-th counted call, starting
        with the first one.
        NOTE(review): presumably te.Linear uses the flag to decide when to
        refresh cached weight state -- confirm against the TE documentation.
        """
        kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0
        # Only training passes with gradients enabled advance the counter, so
        # other forward calls do not shift the micro-batch position.
        if paddle.is_grad_enabled() and self.training:
            self._micro_batch_id += 1
        return super().forward(*args, **kwargs)


class TEPipelineModel(PipelineLayer):
    """Two-layer Linear model (plus cross-entropy loss) for the pipeline parallel test.

    Args:
        in_features: Input size of the first layer and output size of the second.
        hidden_features: Output size of the first layer and input size of the second.
        weight_attrs: Two-element sequence of ``paddle.ParamAttr``; element 0 is
            used for the first layer's weight, element 1 for the second's.
        use_te: Build the layers with ``TELinear`` when True, otherwise with
            ``paddle.nn.Linear``.
        use_fp8: Whether ``forward`` enables ``te.fp8_autocast``.
        accumulate_steps: Forwarded to ``TELinear`` (only when ``use_te`` is True).
        **kwargs: Forwarded unchanged to ``PipelineLayer.__init__``.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        weight_attrs,
        use_te=True,
        use_fp8=False,
        accumulate_steps=1,
        **kwargs,
    ):
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.fp8 = use_fp8
        hcg = fleet.get_hybrid_communicate_group()
        self.dp_group = hcg.get_data_parallel_group()

        Linear = TELinear if use_te else paddle.nn.Linear
        # Only TELinear accepts `accumulate_steps`; paddle.nn.Linear gets no extras.
        extra_kwargs = {}
        if use_te:
            extra_kwargs["accumulate_steps"] = accumulate_steps

        model_desc = [
            LayerDesc(
                Linear,
                self.in_features,
                self.hidden_features,
                weight_attr=weight_attrs[0],
                **extra_kwargs,
            ),
            LayerDesc(
                Linear,
                self.hidden_features,
                self.in_features,
                weight_attr=weight_attrs[1],
                **extra_kwargs,
            ),
        ]
        super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)

    def forward(self, *args, **kwargs):
        """Run the pipeline forward pass inside ``te.fp8_autocast``.

        FP8 is enabled according to ``use_fp8`` from construction, and the data
        parallel group is passed as ``fp8_group``.
        """
        with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group):
            return super().forward(*args, **kwargs)


class StandaloneModel(paddle.nn.Layer):
    """Single-process reference model: two paddle.nn.Linear layers and a cross-entropy loss"""

    def __init__(self, in_features, hidden_features, weight_attrs):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        # Same two-layer stack as the pipeline model, built from plain Paddle layers.
        first = paddle.nn.Linear(
            self.in_features, self.hidden_features, weight_attr=weight_attrs[0]
        )
        second = paddle.nn.Linear(
            self.hidden_features, self.in_features, weight_attr=weight_attrs[1]
        )
        self.layer = paddle.nn.Sequential(first, second)
        self.loss = paddle.nn.CrossEntropyLoss()

    def forward(self, inp):
        """Return the loss for ``inp``, where ``inp[0]`` is the input and ``inp[1]`` the label"""
        logits = self.layer(inp[0])
        return self.loss(logits, inp[1])


class TestLinearPipelineParallel(unittest.TestCase):
    """Tests Linear layer with pipeline parallel"""

    def setUp(self):
        # set_attr() has to run first: init_dist_env() reads batch_size and
        # micro_batch_size, and the default dtype comes from global_dtype.
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": self.pipeline_parallel_size,
        }
        # Number of micro-batches that make up one global batch.
        self.accumulate_steps = self.batch_size // self.micro_batch_size
        strategy.pipeline_configs = {
            "accumulate_steps": self.accumulate_steps,
            "micro_batch_size": self.micro_batch_size,
        }
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 32
        self.micro_batch_size = 16
        self.in_features = 32
        self.hidden_features = 64
        self.global_dtype = "float32"
        self.rtol = 1e-5
        self.atol = 1e-5
        self.iter = 10
        self.fp8 = False

    def test_pipeline_train(self):
        """Test pipeline parallel training

        Trains the pipeline-parallel TE model and a standalone Paddle model
        from the same initial weights on the same batches and checks that the
        losses agree within ``rtol``/``atol`` at every iteration.
        """
        set_random_seed(1024)
        np.random.seed(1024)

        weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
        weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
        weight_attrs = [
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)),
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)),
        ]
        # The TE model is initialised from the transposed arrays; the shape
        # assertions below expect its weights in that transposed layout.
        weight_attrs_transposed = [
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)),
            paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)),
        ]

        pipe_model = TEPipelineModel(
            self.in_features,
            self.hidden_features,
            weight_attrs_transposed,
            use_te=True,
            use_fp8=self.fp8,
            seg_method="layer:Linear",
            num_stages=self.pipeline_parallel_size,
            accumulate_steps=self.accumulate_steps,
        )

        # Check if model is split across ranks as expected
        for name, sublayer in pipe_model.named_sublayers():
            # These sublayers are not the Linear layers under test.
            if name in ("_loss_fn", "shared_layers"):
                continue
            if self.rank == 0:
                assert tuple(sublayer.weight.shape) == weight1_np.T.shape, (
                    f"Shape does not match, expect: {weight1_np.T.shape}, "
                    f"actual: {tuple(sublayer.weight.shape)}"
                )
            elif self.rank == 1:
                assert tuple(sublayer.weight.shape) == weight2_np.T.shape, (
                    f"Shape does not match, expect: {weight2_np.T.shape}, "
                    f"actual: {tuple(sublayer.weight.shape)}"
                )

        standalone_model = StandaloneModel(
            self.in_features,
            self.hidden_features,
            weight_attrs,
        )

        optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
        optimizer_pd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=standalone_model.parameters()
        )

        pipe_model = fleet.distributed_model(pipe_model)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        def train_one_step(layer, inp, optimizer):
            """Run one forward/backward/optimizer step and return the loss."""
            loss = layer(inp)
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            return loss

        for i in range(self.iter):
            inp = paddle.to_tensor(
                np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype
            )
            label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
            loss_te = pipe_model.train_batch([inp, label], optimizer_te)
            loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
            print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}")
            assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)


class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
    """Tests Linear layer with pipeline parallel in FP8"""

    def set_attr(self):
        """Set test configs"""
        # Sizes, dtype and iteration count are the same as in the base test.
        super().set_attr()
        # The FP8 run is compared against the same non-FP8 reference model, so
        # the loss tolerances are looser than in the base test.
        self.rtol = 5e-2
        self.atol = 5e-2
        self.fp8 = True


# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()