import copy
from contextlib import nullcontext

from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
    """Create a model and a sharded copy of it, returning both on CUDA.

    Args:
        model_fn: zero-argument callable that constructs the model.
        enable_fused_normalization: passed through to ``ShardConfig``.
        enable_tensor_parallelism: passed through to ``ShardConfig``.
        use_lazy_init: build the model under ``LazyInitContext`` and
            materialize only the original copy afterwards.

    Returns:
        Tuple of ``(original_model, sharded_model)``, both moved to CUDA.
    """
    init_ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with init_ctx:
        # Build once, then deep-copy so the sharded variant starts from
        # identical weights.
        org_model = model_fn()
        shard_candidate = copy.deepcopy(org_model)
    if use_lazy_init:
        init_ctx.materialize(org_model)

    # Shard the copy according to the requested configuration.
    config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                         enable_tensor_parallelism=enable_tensor_parallelism)
    sharded_model, _shared_params = ShardFormer(shard_config=config).optimize(shard_candidate)
    return org_model.cuda(), sharded_model.cuda()


def build_pipeline_model(model_fn,
                         stage_manager=None,
                         enable_fused_normalization=False,
                         enable_tensor_parallelism=False,
                         use_lazy_init: bool = False):
    """Create a model and a pipeline-sharded copy of it, both on CUDA.

    Args:
        model_fn: zero-argument callable that constructs the model.
        stage_manager: optional pipeline stage manager forwarded to
            ``ShardConfig`` as ``pipeline_stage_manager``.
        enable_fused_normalization: passed through to ``ShardConfig``.
        enable_tensor_parallelism: passed through to ``ShardConfig``.
        use_lazy_init: build the model under ``LazyInitContext`` and
            materialize only the original copy afterwards.

    Returns:
        Tuple of ``(original_model, sharded_model)``, both moved to CUDA.
    """
    init_ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with init_ctx:
        # Build once, then deep-copy so the sharded variant starts from
        # identical weights.
        org_model = model_fn()
        shard_candidate = copy.deepcopy(org_model)
    if use_lazy_init:
        init_ctx.materialize(org_model)

    # Shard the copy, additionally wiring in the pipeline stage manager.
    config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                         enable_tensor_parallelism=enable_tensor_parallelism,
                         pipeline_stage_manager=stage_manager)
    sharded_model, _shared_params = ShardFormer(shard_config=config).optimize(shard_candidate)
    return org_model.cuda(), sharded_model.cuda()


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Run one training-mode forward pass through both models on the same batch.

    Args:
        original_model: the unsharded reference model.
        sharded_model: the sharded model under test.
        data_gen_fn: zero-argument callable returning a kwargs dict of tensors.
        output_transform_fn: applied to each model's raw output.
        loss_fn: maps a transformed output to a loss value.

    Returns:
        Tuple ``(org_output, org_loss, shard_output, shard_loss)``.
    """
    # Generate one batch and move every tensor to the GPU.
    batch = {name: value.cuda() for name, value in data_gen_fn().items()}

    # Both models run in train mode.
    for model in (original_model, sharded_model):
        model.train()

    def _forward(model):
        # Forward pass, output transform, then loss — shared by both models.
        output = output_transform_fn(model(**batch))
        return output, loss_fn(output)

    org_output, org_loss = _forward(original_model)
    shard_output, shard_loss = _forward(sharded_model)
    return org_output, org_loss, shard_output, shard_loss