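"""Utilities for ShardFormer tests: build original/sharded model pairs,
run forward passes on both, and compare their state dicts."""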
import copy
from contextlib import nullcontext

import torch
from torch.nn import Module

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn,
                enable_fused_normalization=True,
                enable_tensor_parallelism=True,
                use_lazy_init: bool = False):
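    """Build an original model and a ShardFormer-sharded copy for comparison.

    When ``use_lazy_init`` is True, both copies are created under a
    ``LazyInitContext`` and only the original model is materialized eagerly
    here; the lazy copy is handed to ShardFormer for sharding.

    Returns:
        Tuple of ``(org_model, sharded_model)``, both moved to CUDA.
    """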
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)
    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def build_pipeline_model(model_fn,
                         stage_manager=None,
                         enable_fused_normalization=False,
                         enable_tensor_parallelism=False,
                         use_lazy_init: bool = False):
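    """Build an original model and a pipeline-sharded copy of it.

    Like ``build_model``, but the ``ShardConfig`` additionally receives
    ``stage_manager`` as its ``pipeline_stage_manager`` so that ShardFormer
    can split the model into pipeline stages; the fused-normalization and
    tensor-parallelism flags default to False here.

    Returns:
        Tuple of ``(org_model, sharded_model)``, both moved to CUDA.
    """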
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               pipeline_stage_manager=stage_manager)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
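    """Run a single forward pass through both models on the same batch.

    ``data_gen_fn`` must return a dict of tensors to be unpacked as keyword
    arguments; ``output_transform_fn`` converts the raw model output into
    the form expected by ``loss_fn``. Both models are put in train mode.

    Returns:
        Tuple of ``(org_output, org_loss, shard_output, shard_loss)``.
    """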
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # switch to train mode
    original_model.train()
    sharded_model.train()
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss


def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
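    """Assert that the sharded model's state dict matches the original's.

    Every key in the original state dict must exist in the sharded one with
    the same shape, dtype, and values. ``name`` is prepended to assertion
    messages to identify which model failed.
    """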
    org_sd = org_model.state_dict()
    shard_sd = sharded_model.state_dict()
    for k, v in org_sd.items():
        assert k in shard_sd, f'{name} {k} not in sharded model'
        shard_v = shard_sd[k]
        assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
        assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
        assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
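

# Illustrative usage in a test (a sketch only: model_fn, data_gen_fn,
# output_transform_fn, and loss_fn stand in for the per-model helpers
# defined elsewhere in the test suite):
#
#   org_model, sharded_model = build_model(model_fn, use_lazy_init=True)
#   check_state_dict(org_model, sharded_model, name='bert')
#   org_output, org_loss, shard_output, shard_loss = run_forward(
#       org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
#   assert torch.allclose(org_loss, shard_loss, atol=1e-5)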