import copy
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
    """Build a reference model and a ShardFormer-sharded copy for comparison tests.

    Args:
        model_fn: zero-argument callable that constructs a fresh model instance.
        enable_fused_normalization: forwarded to ``ShardConfig``.
        enable_tensor_parallelism: forwarded to ``ShardConfig``.
        use_lazy_init: if True, build the model under ``LazyInitContext`` and
            materialize the original model afterwards.

    Returns:
        Tuple ``(org_model, sharded_model)``, both moved to CUDA.
    """
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # Create the model and deep-copy it BEFORE sharding, so the sharded
        # version does not mutate the reference model in place.
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        # NOTE(review): only org_model is materialized here — presumably the
        # sharding pass below handles materialization of the copy; confirm.
        ctx.materialize(org_model)

    # Shard the copy according to the requested configuration.
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Run one training-mode forward pass through both models and compute losses.

    Args:
        original_model: the unsharded reference model.
        sharded_model: the sharded model to compare against.
        data_gen_fn: zero-argument callable returning a dict of input tensors.
        output_transform_fn: maps a raw model output to the form ``loss_fn`` expects.
        loss_fn: maps a transformed output to a loss value.

    Returns:
        Tuple ``(org_output, org_loss, shard_output, shard_loss)``.
    """
    # Prepare input: generate once and move every tensor to CUDA so both
    # models see identical device-resident inputs.
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # Switch to train mode so both models use the same (training) code path.
    original_model.train()
    sharded_model.train()

    # Forward pass through the reference model.
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    # Forward pass through the sharded model with the same inputs.
    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)

    return org_output, org_loss, shard_output, shard_loss