"tests/kit/vscode:/vscode.git/clone" did not exist on "162203105883e7b2b0919b1feeba8531d0ecae21"
auto_parallel_demo.py 5.45 KB
Newer Older
1
import os
2
from pathlib import Path
3

4
import torch
5
from titans.utils import barrier_context
6
7
8
9
10
from torch.fx import GraphModule
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm
11
12
13
14

import colossalai
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
15
16
17
18
19
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
20
from colossalai.core import global_context as gpc
21
22
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
23
24
25
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149

# Directory holding the CIFAR-10 files; can be overridden via the ``DATA`` environment variable.
DATA_ROOT = Path(os.getenv('DATA', './data'))
# Batch size passed to both the train and the test dataloader.
BATCH_SIZE = 1_024
# Number of passes over the training set (the LR scheduler is stepped once per epoch).
NUM_EPOCHS = 10


def _build_dataloaders():
    """Build the CIFAR-10 train and test dataloaders.

    The train split is created inside ``barrier_context`` with ``download=True``.
    Presumably the context makes ranks wait on each other so the download is not
    raced; confirm against ``titans.utils``. The test split is read from the same
    ``DATA_ROOT`` afterwards and relies on the files fetched for the train split.

    Returns:
        ``(train_dataloader, test_dataloader)`` built with ``get_dataloader``.
    """
    with barrier_context():
        train_dataset = CIFAR10(root=DATA_ROOT,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.RandomCrop(size=32, padding=4),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                                ]))

    test_dataset = CIFAR10(root=DATA_ROOT,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                           ]))

    # add_sampler=False: every rank iterates the full dataset and the sharded module
    # splits the batch itself. NOTE(review): this relies on all ranks shuffling
    # identically (same seed) so they see the same batches -- confirm.
    train_dataloader = get_dataloader(
        dataset=train_dataset,
        add_sampler=False,
        shuffle=True,
        batch_size=BATCH_SIZE,
        pin_memory=True,
    )

    test_dataloader = get_dataloader(
        dataset=test_dataset,
        add_sampler=False,
        batch_size=BATCH_SIZE,
        pin_memory=True,
    )
    return train_dataloader, test_dataloader


def _build_sharded_module(model, device_mesh, batch_size):
    """Trace ``model``, solve a sharding strategy on ``device_mesh`` and apply it.

    Args:
        model: the (CUDA) module to parallelize.
        device_mesh: the ``DeviceMesh`` the solver plans against.
        batch_size: batch dimension of the meta input sample used for tracing.

    Returns:
        ``(gm, spec_args)`` where ``gm`` is the transformed ``GraphModule`` and
        ``spec_args`` is ``(sharding_spec_dict, origin_spec_dict, comm_actions_dict)``.
        These must be passed positionally after the input on every forward call.
    """
    # Trace with meta tensors so no real memory is allocated for the sample input.
    tracer = ColoTracer()
    input_sample = {'x': torch.rand([batch_size, 3, 32, 32]).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # Prepare the information the solver needs.
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)

    # Solve: the first element of the result holds one strategy index per graph node.
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    if gpc.get_global_rank() == 0:
        for index, node in enumerate(graph.nodes):
            print(node.name, node.strategies_vector[solution[index]].name)

    # Rewrite the graph so it can execute the chosen distributed strategy.
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
    gm = runtime_apply_pass(gm)
    gm.recompile()
    return gm, (sharding_spec_dict, origin_spec_dict, comm_actions_dict)


def _train_one_epoch(gm, train_dataloader, criterion, optimizer, spec_args, show_progress):
    """Run one training epoch and return the sample-weighted mean training loss.

    Returns ``0.0`` if the dataloader yields no batches.
    """
    gm.train()
    loss_sum = 0.0
    num_samples = 0
    for img, label in tqdm(train_dataloader, disable=not show_progress):
        img = img.cuda()
        label = label.cuda()
        optimizer.zero_grad()
        output = gm(img, *spec_args)
        loss = criterion(output, label)
        # The loss is a scalar, so backward() seeds it with 1. Passing the loss itself
        # as the gradient would scale every gradient by the current loss value.
        loss.backward()
        optimizer.step()
        # Accumulate on-device (detached) to avoid a host sync every step.
        loss_sum += loss.detach() * img.size(0)
        num_samples += img.size(0)
    return loss_sum / max(num_samples, 1)


def _evaluate(gm, test_dataloader, criterion, spec_args):
    """Evaluate ``gm`` on the test set.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over all test samples (``0.0``
        each if the dataloader is empty).
    """
    gm.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for img, label in test_dataloader:
            img = img.cuda()
            label = label.cuda()
            output = gm(img, *spec_args)
            loss_sum += criterion(output, label) * img.size(0)
            pred = torch.argmax(output, dim=-1)
            correct += torch.sum(pred == label)
            total += img.size(0)
    total = max(total, 1)
    return loss_sum / total, correct / total


def main():
    """Train ResNet-50 on CIFAR-10 using ColossalAI's auto-parallel sharding solver.

    The run assumes exactly 4 processes (launched so that ``launch_from_torch``
    can initialize them), arranged as a 2x2 device mesh. Per-epoch mean
    train/test loss, test accuracy and learning rate are logged on rank 0.
    """
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()

    train_dataloader, test_dataloader = _build_dataloaders()

    # Initialize the device mesh: 4 ranks viewed as a 2x2 logical mesh.
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    model = resnet50(num_classes=10).cuda()
    gm, spec_args = _build_sharded_module(model, device_mesh, BATCH_SIZE)

    criterion = torch.nn.CrossEntropyLoss()

    # Optimize the parameters of the module that is actually trained. The
    # preparation pass may re-register sharded parameters on its submodules.
    optimizer = torch.optim.SGD(gm.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    # total_steps counts epochs: the scheduler is stepped once per epoch below.
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)

    show_progress = gpc.get_global_rank() == 0
    for epoch in range(NUM_EPOCHS):
        train_loss = _train_one_epoch(gm, train_dataloader, criterion, optimizer, spec_args, show_progress)
        lr_scheduler.step()

        test_loss, acc = _evaluate(gm, test_dataloader, criterion, spec_args)

        # float(...) turns 0-dim tensors into plain floats so the format specs always apply.
        logger.info(
            f"Epoch {epoch} - train loss: {float(train_loss):.5}, test loss: {float(test_loss):.5}, acc: {float(acc):.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
            ranks=[0])


if __name__ == '__main__':
    # Run the demo only when executed as a script, never on import.
    main()