# test_cuda_rpc_performance.py
1
2
3
4
5
6
import os
import time

import pytest
import torch
import torch.nn as nn
7
from rpc_test_utils import parse_args, rpc_run
8
9
10
11
from titans.dataloader.cifar10 import build_cifar
from torchvision.models import resnet50
from tqdm import tqdm

12
13
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.pipeline.rpc import OneFOneBPipelineEngine
14
15
16
17
18
19


def flatten(x):
    """Collapse every dimension after dimension 0 into a single dimension.

    Used as a functional (module-less) step in the pipeline execution
    sequence built by ``partition``.
    """
    return torch.flatten(x, start_dim=1)


20
def partition(pp_rank: int, chunk: int, stage_num: int):
    """Build ResNet-50 and return the piece of it assigned to one pipeline rank.

    The model is constructed inside a ``PipelinableContext`` and then laid out
    as an explicit execution sequence (``exec_seq``). That sequence also
    contains the functional ``flatten`` step between ``avgpool`` and ``fc``.

    Args:
        pp_rank: Pipeline rank whose partition should be returned.
        chunk: Number of model chunks, forwarded to
            ``PipelinableContext.partition``.
        stage_num: Total number of pipeline stages.

    Returns:
        The result of ``PipelinableContext.partition(chunk, stage_num, pp_rank)``.
        NOTE(review): presumably an ``nn.Module`` for this stage; confirm
        against colossalai.
    """
    pipelinable = PipelinableContext()

    # build model partitions: constructing the model inside the context lets
    # it record the layers; the returned module itself is not kept.
    with pipelinable:
        # input : [B, 3, 32, 32]
        _ = resnet50()

    pipelinable.policy = "customized"

    # The (flatten, "behind") entry injects a plain function into the layer
    # list. NOTE(review): "behind" presumably means "run after the preceding
    # entry ('avgpool')"; confirm against PipelinableContext.to_layer_list.
    exec_seq = [
        'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
    ]
    pipelinable.to_layer_list(exec_seq)

    # Named stage_module (not `partition`) to avoid shadowing this function.
    stage_module = pipelinable.partition(chunk, stage_num, pp_rank)
    return stage_module


def run_master(args):
    """Run one pass over the training set with a 1F1B RPC pipeline and report timing.

    The model is built by ``partition``. The data comes from ``build_cifar``,
    using the ``DATA`` environment variable as the root (default ``./data``).
    Each batch goes through ``forward_backward`` with ``forward_only=False``.
    The total wall time and the time per batch are printed.

    Args:
        args: Parsed CLI namespace. Reads ``batch_size``, ``chunk``,
            ``device``, ``world_size`` and ``num_microbatches``.
    """
    batch_size = args.batch_size
    chunk = args.chunk
    device = args.device
    stage_num = args.world_size  # one pipeline stage per worker
    num_microbatches = args.num_microbatches

    # build dataloader (the test split is not used by this benchmark)
    root = os.environ.get('DATA', './data')
    train_dataloader, _ = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
    criterion = nn.CrossEntropyLoss()

    pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
                                       stage_num=stage_num,
                                       num_microbatches=num_microbatches,
                                       device=device,
                                       chunk=chunk,
                                       criterion=criterion,
                                       checkpoint=False)

    pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    # perf_counter is monotonic and high-resolution, which suits interval
    # timing better than wall-clock time.time().
    start = time.perf_counter()
    for bx, by in tqdm(train_dataloader):
        pp_engine.forward_backward(bx, labels=by, forward_only=False)
    cost_time = time.perf_counter() - start

    num_batches = len(train_dataloader)
    print("total cost time :", cost_time)
    # Guard against ZeroDivisionError when the dataloader yields nothing.
    if num_batches:
        print("cost time per batch:", cost_time / num_batches)
    else:
        print("cost time per batch: n/a (empty dataloader)")


@pytest.mark.skip("Test for performance, no need for CI")
def main():
    """Parse CLI arguments, pin the pipeline layout, and launch ``run_master`` via RPC."""
    cli_args = parse_args()
    # The partition function only supports two stages with a single chunk,
    # so these values override whatever was given on the command line.
    cli_args.world_size, cli_args.chunk = 2, 1
    rpc_run(cli_args, run_master)


# Script entry point: run the benchmark when this file is executed directly.
if __name__ == '__main__':
    main()