# test_engine.py: distributed smoke test of the ColossalAI engine with no AMP, torch AMP, apex AMP and naive AMP.
from functools import partial

import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.amp import AMP_TYPE
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.utils import free_port

from tests.components_to_test.registry import non_distributed_component_funcs

# Baseline config for the engine tests: pipeline and tensor parallel sizes of 1,
# fp16 mode unset (None), and clip_grad_norm=1.0. The run_with_*_amp helpers
# build their AMP-specific configs from this dict.
CONFIG = {
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'fp16': {'mode': None},
    'clip_grad_norm': 1.0,
}


def run_train():
    """Run one training step for each test model through a ColossalAI engine.

    For every name in ``test_models``, the model builder, train dataloader,
    optimizer class and criterion come from ``non_distributed_component_funcs``.
    They are wrapped with ``colossalai.initialize``, and a single
    zero_grad / forward / backward / step is run on the first batch.

    The engine's AMP behaviour is presumably taken from whatever config the
    caller installed on the global context before calling this function
    (see the ``run_with_*`` helpers). TODO confirm against colossalai.initialize.
    """
    # FIXME: test bert
    test_models = ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        model = model_builder(checkpoint=False)
        optimizer = optimizer_class(model.parameters(), lr=1e-3)
        # Only the engine and the (possibly wrapped) train dataloader are used;
        # the remaining return values are intentionally ignored.
        engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                             optimizer=optimizer,
                                                             criterion=criterion,
                                                             train_dataloader=train_dataloader)

        try:
            engine.train()
            for data, label in train_dataloader:
                engine.zero_grad()
                data = data.cuda()
                label = label.cuda()

                # Some components ship without a criterion. In that case the
                # model itself is called with the label and returns the loss.
                # Test identity, not truthiness, because a criterion is an object
                # whose truth value should not matter.
                if criterion is not None:
                    output = engine(data)
                    loss = engine.criterion(output, label)
                else:
                    loss = engine(data, label)

                engine.backward(loss)
                engine.step()
                # A single step per model is enough for this smoke test.
                break
        except IndexError:
            # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue
            # the following check fails in apex
            # if cached_x.grad_fn.next_functions[1][0].variable is not x:
            continue


def run_with_no_amp():
    """Run the training smoke test without touching the global context config.

    Unlike the other ``run_with_*`` helpers, this does not install a config
    on ``gpc`` first. It presumably uses the one set up by ``colossalai.launch``.
    """
    run_train()


def run_with_torch_amp():
    """Run the training smoke test with ``fp16.mode`` set to ``AMP_TYPE.TORCH``."""
    # Derive a fresh config instead of mutating the module-level CONFIG in place.
    # Mutating it would leak this AMP mode into any later reader of CONFIG.
    config = {**CONFIG, 'fp16': {**CONFIG['fp16'], 'mode': AMP_TYPE.TORCH}}
    # hack config: replace the global context's config directly instead of re-launching
    gpc._config = Config(config)
    run_train()


def run_with_apex_amp():
    """Run the training smoke test with ``fp16.mode`` set to ``AMP_TYPE.APEX``."""
    # Derive a fresh config instead of mutating the module-level CONFIG in place.
    # Mutating it would leak this AMP mode into any later reader of CONFIG.
    config = {**CONFIG, 'fp16': {**CONFIG['fp16'], 'mode': AMP_TYPE.APEX}}
    # hack config: replace the global context's config directly instead of re-launching
    gpc._config = Config(config)
    run_train()


def run_with_naive_amp():
    """Run the training smoke test with ``fp16.mode`` set to ``AMP_TYPE.NAIVE``."""
    # Derive a fresh config instead of mutating the module-level CONFIG in place.
    # Mutating it would leak this AMP mode into any later reader of CONFIG.
    config = {**CONFIG, 'fp16': {**CONFIG['fp16'], 'mode': AMP_TYPE.NAIVE}}
    # hack config: replace the global context's config directly instead of re-launching
    gpc._config = Config(config)
    run_train()


def run_engine(rank, world_size, port):
    """Per-process worker: initialise the distributed env, then run every AMP variant.

    Args:
        rank: index of this process. ``mp.spawn`` passes it as the first
            positional argument.
        world_size: total number of participating processes.
        port: TCP port on localhost used for the process-group rendezvous.
    """
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The no-AMP run goes first. It uses the config installed by launch above,
    # and each AMP helper then installs its own config before training.
    run_with_no_amp()
    run_with_torch_amp()
    run_with_apex_amp()
    run_with_naive_amp()


@pytest.mark.dist
def test_engine():
    """Spawn ``world_size`` processes that each run :func:`run_engine`."""
    world_size = 2
    # mp.spawn calls the target as fn(process_index, *args), so rank is supplied
    # positionally and the remaining arguments are bound here by keyword.
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()