import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
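    """Build a reference model and a ShardFormer-sharded deep copy, and move both to CUDA."""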
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)
    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def build_pipeline_model(model_fn,
                         stage_manager=None,
                         enable_fused_normalization=False,
                         enable_tensor_parallelism=False,
                         use_lazy_init: bool = False):
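    """Build a reference model and a pipeline-sharded copy driven by `stage_manager`, and move both to CUDA."""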
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               pipeline_stage_manager=stage_manager)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
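    """Run one forward pass on both models with the same batch and return their outputs and losses."""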
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # switch to train mode
    original_model.train()
    sharded_model.train()
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss


def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
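    """Assert that the sharded model's state dict matches the original in keys, shapes, dtypes and values."""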
    org_sd = org_model.state_dict()
    shard_sd = sharded_model.state_dict()
    for k, v in org_sd.items():
        assert k in shard_sd, f'{name} {k} not in sharded model'
        shard_v = shard_sd[k]
        assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
        assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
        assert torch.equal(v, shard_v), f'{name} {k} value mismatch'


def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
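    """Build a reference model and a copy boosted by HybridParallelPlugin, each with its own Adam optimizer.

    All entries of `test_config` except `use_lazy_init` are forwarded to HybridParallelPlugin.
    """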

    use_lazy_init = False
    if 'use_lazy_init' in test_config:
        use_lazy_init = test_config.pop('use_lazy_init')

    if use_lazy_init:
        ctx = LazyInitContext()
    else:
        ctx = nullcontext()

    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)

    with ctx:
        org_model = model_fn().cuda()
        sharded_model = copy.deepcopy(org_model)

    if use_lazy_init:
        org_model = ctx.materialize(org_model)

    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)

    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster


def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
                                            data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
                                            booster: Booster):
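    """Run forward and backward on both models, using the booster's pipeline schedule when a stage manager is set."""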

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()
    sharded_model.train()
    if booster.plugin.stage_manager is not None:
        data = {
            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
            for k, v in data.items()
        }
        data_iter = iter([data])
        sharded_output = booster.execute_pipeline(data_iter,
                                                  sharded_model,
                                                  _criterion,
                                                  sharded_optimizer,
                                                  return_loss=True,
                                                  return_outputs=True)
        sharded_loss = sharded_output['loss']
    else:
        data = {k: v.cuda() for k, v in data.items()}
        sharded_output = sharded_model(**data)
        sharded_loss = criterion(sharded_output)
        sharded_loss.backward()

    org_model.train()
    org_output = org_model(**data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output


def check_output_hidden_state(org_output: Tensor,
                              sharded_output: Tensor,
                              stage_manager: Optional[PipelineStageManager] = None,
                              atol: float = 1e-5,
                              rtol: float = 1e-3):
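    """Compare last hidden states; under pipeline parallelism, stage outputs are concatenated on the last stage."""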

    org_hidden_state = org_output.last_hidden_state

    if stage_manager is None:
        sharded_hidden_state = sharded_output.last_hidden_state

    if stage_manager and stage_manager.is_last_stage():
        sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)

    assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
        f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
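    """Assert that the original and sharded losses match within the given tolerances."""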
    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
        f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"


def check_weight(org_model: Module,
                 sharded_model: Module,
                 layer_suffix: List[str],
                 tp_group: Optional[ProcessGroup] = None,
                 dim: int = 0,
                 atol: float = 1e-5,
                 rtol: float = 1e-3,
                 verbose: bool = False):
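    """Compare the weights of the given layers; tensor-parallel shards are all-gathered along `dim` first."""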

    for suffix in layer_suffix:
        org_weight = getattr_(org_model, suffix).weight
        sharded_weight = getattr_(sharded_model, suffix).weight

        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
            sharded_weight_list = [
                torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
            ]
            dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
            sharded_weight = torch.cat(sharded_weight_list, dim=dim)

        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

        assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
            f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"


def check_grad(org_model: Module,
               sharded_model: Module,
               layer_suffix: List[str],
               tp_group: ProcessGroup = None,
               dim: int = 0,
               atol: float = 1e-5,
               rtol: float = 1e-3,
               verbose: bool = False):
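    """Compare the weight gradients of the given layers; tensor-parallel shards are all-gathered along `dim` first."""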

    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight

        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [
                torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
            ]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            shard_grad = shard_grad[:org_grad.shape[0], :]

        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
        assert torch.allclose(
            org_grad, shard_grad, rtol=rtol, atol=atol
        ), f"error attribute '{suffix}', origin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
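

# Minimal usage sketch (assumes torch.distributed is already initialized, e.g. via colossalai.launch,
# and that the caller provides the model_fn / data_gen_fn / output_transform_fn / loss_fn callables):
#
#   org_model, sharded_model = build_model(model_fn)
#   org_output, org_loss, shard_output, shard_loss = run_forward(
#       org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
#   check_loss(org_loss, shard_loss)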