# _utils.py
import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
    """Build a reference model and a ShardFormer-sharded copy of it.

    Args:
        world_size: Value passed as ``tensor_parallel_size`` to ``ShardConfig``.
        model_fn: Zero-argument factory that returns a fresh model instance.

    Returns:
        A ``(org_model, sharded_model)`` tuple. ``org_model`` is the freshly
        built model moved to CUDA. ``sharded_model`` is the result of
        ``ShardFormer.shard_model`` applied to a deep copy of it.

    NOTE(review): ``shard_former.init_distributed()`` presumably needs a
    distributed environment to be launched already. Confirm with the callers.
    """
    # Create a new model and move it to the GPU.
    org_model = model_fn().cuda()

    # Shard with tensor parallelism across `world_size` ranks and fused layernorm enabled.
    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)

    # Shard a deep copy, not the original. This keeps `org_model` as an
    # untouched reference in case `shard_model` modifies its argument in place.
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)

    return org_model, sharded_model


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Run one training-mode forward pass through both models on identical inputs.

    Args:
        original_model: The unsharded reference model.
        sharded_model: The sharded model under test.
        data_gen_fn: Zero-argument callable returning a dict of model inputs.
            Every value must support ``.cuda()``.
        output_transform_fn: Callable applied to each model's raw output.
        loss_fn: Callable computing a loss from the transformed output.

    Returns:
        ``(org_output, org_loss, shard_output, shard_loss)``, where each output
        has already been passed through ``output_transform_fn``.
    """
    # Prepare the input. The same dict of CUDA inputs is fed to both models
    # so their results are directly comparable.
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # Switch both models to train mode.
    original_model.train()
    sharded_model.train()

    # Forward pass through the reference model.
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    # Forward pass through the sharded model.
    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)

    return org_output, org_loss, shard_output, shard_loss