# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for group sharding"""

import unittest

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
    DygraphShardingOptimizer,)

from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te


class TestGroupSharding(unittest.TestCase):
    """Tests group sharding of ``te.Linear`` against ``paddle.nn.Linear``.

    Each test builds a Transformer Engine ``Linear`` and a Paddle ``Linear`` with identical
    initial weights, wraps both with the same sharding stage (1, 2 or 3), trains them side by
    side on identical inputs, checks that the losses match, and then checks how optimizer
    state / gradients are distributed across ranks.
    """

    # ``level`` strings accepted by ``group_sharded_parallel`` for stages 2 and 3
    # (stage 1 goes through ``DygraphShardingOptimizer`` instead).
    _GROUP_SHARDED_LEVELS = {2: 'os_g', 3: 'p_g_os'}

    def setUp(self):
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def set_attr(self):
        """Set test configs"""
        self.sharding_degree = 2
        self.global_dtype = 'float32'
        self.rtol = 1e-5
        self.atol = 1e-5
        self.batch_size = 16
        self.in_channels = 16
        self.out_channels = 32
        # Passed to ``te.fp8_autocast`` for the TE model.
        self.fp8 = False

    def init_dist_env(self):
        """Init Paddle Fleet environment with a pure sharding-parallel topology."""
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": 1,
            "sharding_degree": self.sharding_degree,
        }
        self.strategy = strategy
        fleet.init(is_collective=True, strategy=strategy)

    def _get_model_and_optimizer(self, model, stage):
        """Wrap ``model`` and a fresh AdamW optimizer for the given sharding stage.

        Args:
            model: The layer to train.
            stage: Sharding stage. 1 uses ``DygraphShardingOptimizer`` through fleet;
                2 and 3 use ``paddle.distributed.sharding.group_sharded_parallel``.

        Returns:
            A ``(model, optimizer)`` tuple holding the wrapped objects.

        Raises:
            ValueError: If ``stage`` is not 1, 2 or 3.
        """
        if stage == 1:
            optimizer = DygraphShardingOptimizer(
                paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()),
                fleet.get_hybrid_communicate_group(),
            )
            model = fleet.distributed_model(model)
            optimizer = fleet.distributed_optimizer(optimizer)
        elif stage in self._GROUP_SHARDED_LEVELS:
            optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters())
            group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
            model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel(
                model=model,
                optimizer=optimizer,
                level=self._GROUP_SHARDED_LEVELS[stage],
                group=group,
                segment_size=256,
            )
        else:
            raise ValueError(f"Stage {stage} not supported")
        return model, optimizer

    def _build_model_pair(self, stage):
        """Create identically initialised TE and Paddle Linear layers wrapped for ``stage``.

        Returns:
            ``(model_te, optimizer_te, model_pd, optimizer_pd)``.
        """
        set_random_seed(1024)
        model_te = te.Linear(self.in_channels, self.out_channels)
        model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
        # The TE weight is copied transposed: the two layers lay out the weight matrix
        # transposed relative to each other.
        model_pd.weight.copy_(model_te.weight.T, True)
        model_pd.bias.copy_(model_te.bias, True)

        model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=stage)
        model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=stage)
        return model_te, optimizer_te, model_pd, optimizer_pd

    def _train_and_compare(self,
                           model_te,
                           optimizer_te,
                           model_pd,
                           optimizer_pd,
                           after_backward=None,
                           num_steps=5):
        """Train both models on identical random inputs and assert their losses match.

        Args:
            model_te, optimizer_te: Wrapped TE model and its optimizer.
            model_pd, optimizer_pd: Wrapped Paddle model and its optimizer.
            after_backward: Optional callable invoked with the model after ``backward()``
                and before ``optimizer.step()``; used to inspect gradients.
            num_steps: Number of training steps to run.
        """
        # Seed with the rank id so ranks draw different data, while the TE and Paddle
        # models on the same rank always see the same input tensor.
        paddle.seed(paddle.distributed.get_rank())

        def train_one_step(model, inp, optimizer):
            loss = model(inp).mean()
            loss.backward()
            if after_backward is not None:
                after_backward(model)
            optimizer.step()
            optimizer.clear_grad()
            return loss

        for _ in range(num_steps):
            inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
            with te.fp8_autocast(enabled=self.fp8):
                loss_te = train_one_step(model_te, inp, optimizer_te)
            loss_pd = train_one_step(model_pd, inp, optimizer_pd)
            assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)

    def test_group_sharding_stage1(self):
        """Tests stage-1 sharding (optimizer states sharded across ranks)."""
        model_te, optimizer_te, model_pd, optimizer_pd = self._build_model_pair(stage=1)
        self._train_and_compare(model_te, optimizer_te, model_pd, optimizer_pd)

        # NOTE(review): presumably AdamW keeps 4 state entries per parameter and each of
        # the 2 ranks owns one of the layer's 2 parameters -- confirm if this count changes.
        assert len(optimizer_te.state_dict()) == 4, \
            "Expect each rank to hold 4 optimizer state entries."

    def test_group_sharding_stage2(self):
        """Tests stage-2 sharding (optimizer states and gradients sharded across ranks)."""
        model_te, optimizer_te, model_pd, optimizer_pd = self._build_model_pair(stage=2)
        rank_id = paddle.distributed.get_rank()

        def check_grad_split(model):
            # Check gradients are split to different trainers: each rank keeps only the
            # gradient of the parameter it owns.
            if rank_id == 0:
                assert model.bias.grad is None and model.weight.grad is not None
            elif rank_id == 1:
                assert model.weight.grad is None and model.bias.grad is not None

        self._train_and_compare(model_te,
                                optimizer_te,
                                model_pd,
                                optimizer_pd,
                                after_backward=check_grad_split)

        assert len(optimizer_te.state_dict()) == 4, \
            "Expect each rank to hold 4 optimizer state entries."

    def test_group_sharding_stage3(self):
        """Tests stage-3 sharding (parameters, gradients and optimizer states sharded)."""
        model_te, optimizer_te, model_pd, optimizer_pd = self._build_model_pair(stage=3)
        self._train_and_compare(model_te, optimizer_te, model_pd, optimizer_pd)

        # NOTE(review): this loop passes vacuously if no state name ends with
        # 'w_0_moment1_0'; confirm the TE weight's state is actually matched here.
        for name, value in optimizer_te.state_dict().items():
            if name.endswith('w_0_moment1_0'):
                assert value.numel() == \
                    self.in_channels * self.out_channels // self.sharding_degree, \
                    "Expect optimizer state to be sharded across trainers."


if __name__ == "__main__":
    # Run the test cases when this file is executed directly.
    unittest.main()