"examples/language/openmoe/model/openmoe_base_config.json" did not exist on "8993c8a8170ac116a551840e3d442af78bedc53e"
_utils.py 12.4 KB
Newer Older
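"""Utility helpers shared by the ShardFormer model tests.

These functions build matched (original, sharded) model pairs, run
forward/backward passes under tensor / pipeline parallelism through the
``HybridParallelPlugin`` booster, and compare outputs, losses, weights and
gradients between the two models.
"""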

import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(
    model_fn,
    enable_fused_normalization=True,
    enable_tensor_parallelism=True,
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
    use_lazy_init: bool = False,
):
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        enable_flash_attention=enable_flash_attention,
        enable_jit_fused=enable_jit_fused,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()
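

# A minimal usage sketch for ``build_model`` (the GPT-2 constructor is only an
# illustration; any zero-argument callable returning an ``nn.Module`` works):
#
#   from transformers import GPT2Config, GPT2Model
#
#   org_model, sharded_model = build_model(
#       lambda: GPT2Model(GPT2Config()),
#       enable_tensor_parallelism=True,
#       use_lazy_init=True,
#   )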


def build_pipeline_model(
    model_fn,
    stage_manager=None,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    use_lazy_init: bool = False,
    policy: Optional[Policy] = None,
):
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        pipeline_stage_manager=stage_manager,
    )

    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
    return org_model.cuda(), sharded_model.cuda()
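

# Sketch of driving ``build_pipeline_model`` (assumes ``colossalai.launch`` has
# already initialized distributed state; the two-stage mesh is illustrative):
#
#   from colossalai.cluster import ProcessGroupMesh
#
#   pg_mesh = ProcessGroupMesh(2)  # a single axis holding 2 pipeline stages
#   stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
#   org_model, sharded_model = build_pipeline_model(model_fn, stage_manager)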


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # switch to train mode
    original_model.train()
    sharded_model.train()
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss


def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
    org_sd = org_model.state_dict()
    shard_sd = sharded_model.state_dict()
    for k, v in org_sd.items():
        assert k in shard_sd, f"{name} {k} not in sharded model"
        shard_v = shard_sd[k]
        assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}"
        assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}"
        assert torch.equal(v, shard_v), f"{name} {k} value mismatch"


def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
    use_lazy_init = False
    if "use_lazy_init" in test_config:
        use_lazy_init = test_config.pop("use_lazy_init")

    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    org_model = org_model.cuda()
    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
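

# ``test_config`` is forwarded straight to ``HybridParallelPlugin`` once the
# ``use_lazy_init`` key has been popped, so it is a plain dict of plugin kwargs
# (the values below are an illustrative sketch):
#
#   test_config = {
#       "tp_size": 2,
#       "pp_size": 2,
#       "num_microbatches": 4,
#       "precision": "fp16",
#       "use_lazy_init": True,  # consumed here, never forwarded to the plugin
#   }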


def run_forward_backward_with_hybrid_plugin(
    org_model: Module,
    sharded_model: Module,
    sharded_optimizer: Optimizer,
    data_gen_fn: Callable,
    output_transform_fn: Callable,
    criterion: Callable,
    booster: Booster,
):
    org_model.cuda()
    sharded_model.cuda()

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()

    if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
        # Sequence parallelism splits the sequence dim across the TP group, so
        # tile the last dim of every input-shaped tensor until the sequence
        # length reaches lcm(seq_len, tp_size) and divides evenly.
        seq_len = data["input_ids"].shape[-1]
        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
        times = lcm // seq_len
        input_shape = data["input_ids"].shape
        for k, v in data.items():
            if v.shape == input_shape:
                data[k] = v.repeat(*([1] * (v.dim() - 1)), times)

    sharded_model.train()
    if booster.plugin.stage_manager is not None:
        for k, v in data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                # inflate the batch dim so the pipeline scheduler can split it into microbatches
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                data[k] = v.to("cuda").repeat(*new_shape)

        data_iter = iter([data])
        sharded_output = booster.execute_pipeline(
            data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
        )
        sharded_loss = sharded_output["loss"]
    else:
        data = {k: v.cuda() for k, v in data.items()}
        sharded_output = sharded_model(**data)
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)

    org_model.train()
    data = {k: v.cuda() for k, v in data.items()}
    org_output = org_model(**data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output
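

# Typical flow in a test body: build the pair, run one forward/backward, then
# assert with the check helpers below (a sketch; tolerances are illustrative):
#
#   org_model, org_optim, sharded_model, sharded_optim, criterion, booster = \
#       build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
#   org_loss, org_output, sharded_loss, sharded_output = \
#       run_forward_backward_with_hybrid_plugin(
#           org_model, sharded_model, sharded_optim,
#           data_gen_fn, output_transform_fn, criterion, booster,
#       )
#   check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)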


def check_output_hidden_state(
    org_output: Tensor,
    sharded_output: Tensor,
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    dim: int = 0,
):
    org_hidden_state = org_output.last_hidden_state

    if stage_manager and stage_manager.is_last_stage():
        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

    assert torch.allclose(
        org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol
    ), f"sharded model's last hidden state is not equal to the original model's\n{org_hidden_state}\n{sharded_hidden_state}"


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
    assert torch.allclose(
        org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol
    ), f"sharded model loss is not equal to original model loss\n{org_loss}\n{sharded_loss}"


def check_weight(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    for suffix in layer_suffix:
        org_weight = getattr_(org_model, suffix).weight
        sharded_weight = getattr_(sharded_model, suffix).weight

        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
            sharded_weight_list = [
                torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
            ]
            dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
            sharded_weight = torch.cat(sharded_weight_list, dim=dim)

        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

        assert torch.allclose(
            org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol
        ), f"sharded model weight {suffix} is not equal to original model weight\n{org_weight}\n{sharded_weight}"


def get_grad_tensors_for_check(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
    name: Optional[str] = None,
):
    grad_to_check = {}
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            shard_grad = shard_grad[: org_grad.shape[0], :]
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        grad_to_check[suffix] = {
            "org_grad": org_grad.float(),
            "shard_grad": shard_grad.float(),
            "rtol": rtol,
            "atol": atol,
        }

    return grad_to_check
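

# ``get_grad_tensors_for_check`` only snapshots gradients; the assertion is
# deferred to ``check_all_grad_tensors`` at the bottom of this module, so grads
# can be captured right after backward and verified after later steps mutate
# the models (a sketch; the layer suffixes are illustrative):
#
#   grads_to_check = get_grad_tensors_for_check(
#       org_model, sharded_model,
#       ["embed_tokens", "layers.0.self_attn.q_proj"], tp_group, dim=0)
#   ...  # optimizer.step(), weight checks, etc.
#   check_all_grad_tensors(grads_to_check)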


# used by sam/blip2
def check_grad(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            shard_grad = shard_grad[: org_grad.shape[0], :]
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        assert torch.allclose(
            org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
        ), f"error attribute '{suffix}': original model grad is not equal to sharded model grad\n{org_grad}\n{shard_grad}"


def unwrap_model(
    module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None
):
    if isinstance(module, HybridParallelModule):
        module = module.unwrap()
    if base_model_class_name is None:
        return module
    if module.__class__.__name__ == base_model_class_name:
        return module
    return getattr(module, base_model_attribute_name, None)
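

# Example (class/attribute names illustrative): for a booster-wrapped causal
# LM, ``unwrap_model(sharded_model, "LlamaModel", "model")`` first strips the
# ``HybridParallelModule`` wrapper, then descends into the ``model`` attribute
# when the unwrapped object is not already a ``LlamaModel``.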


def check_all_grad_tensors(check_tensors):
    """
    "org_grad": tensor to be compared from the original model
    "shard_grad": tensor to be compared from the sharded model
    """
    for suffix, check_info in check_tensors.items():
        org_grad = check_info["org_grad"]
        shard_grad = check_info["shard_grad"]
        rtol = check_info["rtol"]
        atol = check_info["atol"]
        assert torch.allclose(
            org_grad, shard_grad, atol=atol, rtol=rtol
        ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"