import torch
from torch import nn

from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from rpc_test_utils import rpc_run, parse_args, MLP
from functools import partial

# global variable for model created
batch_size = 16  # samples per input batch fed to the pipeline
dim = 10  # feature width of the MLP input (and of each hidden layer)

def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
    """Trace ``model``, split it into ``stage_num`` balanced pipeline stages,
    attach the fx topology to each stage, and return the stage for ``pp_rank``.

    ``data_kwargs`` supplies sample keyword tensors; they are moved to the
    'meta' device so tracing allocates no real memory.
    """
    model.eval()
    tracer = ColoTracer()
    meta_kwargs = {name: tensor.to('meta') for name, tensor in data_kwargs.items()}
    traced = tracer.trace(root=model, meta_args=meta_kwargs)
    graph_module = torch.fx.GraphModule(model, traced, model.__class__.__name__)
    annotated = balanced_split_pass(graph_module, stage_num)
    top_module, stage_modules = split_with_split_nodes_pass(annotated, merge_output=True)
    # Share one topology object across all stages so each stage knows the
    # inter-stage communication layout.
    topology = get_fx_topology(top_module)
    for stage in stage_modules:
        if isinstance(stage, torch.fx.GraphModule):
            stage._topo = topology
    # NOTE(review): index is pp_rank + 1 — presumably stage_modules[0] is not a
    # per-rank stage (e.g. a wrapper/input module); confirm against the split pass.
    return stage_modules[pp_rank + 1]

def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
    """Partition factory passed to the pipeline engine.

    Seeds the RNG with a fixed value so every rank constructs identical MLP
    weights, then returns the stage module owned by ``pp_rank``. ``chunk`` is
    part of the engine's expected callback signature but is not used here.
    """
    torch.manual_seed(1024)
    full_model = MLP(dim, stage_num * 3)
    return create_partition_module(pp_rank, stage_num, full_model, data_kwargs)

def run_master(args):
    """Master-rank driver: build a 1F1B pipeline engine over ``args.world_size``
    stages and run ``args.epoch`` forward-only iterations on a random batch."""
    torch.manual_seed(100)

    device = args.device
    stage_num = args.world_size

    input_sample = torch.randn((batch_size, dim), device=device)

    def data_gen():
        # Sample kwargs used only for meta-tracing inside `partition`.
        return dict(x=torch.zeros((batch_size, dim)))

    data_kwargs = data_gen()
    engine = OneFOneBPipelineEngine(
        partition_fn=partial(partition, data_kwargs),
        stage_num=stage_num,
        num_microbatches=args.num_microbatches,
        device=device,
        chunk=args.chunk,
        checkpoint=args.use_checkpoint,
    )

    # forward_only: this test exercises the forward middleware path; the
    # backward pass is intentionally skipped.
    for _ in range(args.epoch):
        engine.forward_backward({'x': input_sample}, forward_only=True)

if __name__ == "__main__":
    # Parse CLI options and launch the RPC workers, running `run_master`
    # on the master process.
    rpc_run(parse_args(), run_master)