import pytest
import colossalai
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.core import global_context as gpc
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs

# Single-process parallel layout (pipeline size 1, tensor size 1); the fp16
# mode is overwritten per-test inside `run_train`, and gradients are clipped
# to an L2 norm of 1.0.
CONFIG = {
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'fp16': {'mode': None},
    'clip_grad_norm': 1.0,
}


@parameterize('model_name', ['repeated_computed_layers', 'resnet18'])
@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
def run_train(model_name, amp_mode):
    """Run a single training step of `model_name` under the given AMP mode.

    Builds the model/dataloader/optimizer/criterion from the test component
    registry, initializes a legacy ColossalAI engine, and executes one
    forward/backward/step iteration to smoke-test the engine.

    Args:
        model_name (str): key into `non_distributed_component_funcs`.
        amp_mode: an `AMP_TYPE` member or None (no mixed precision).
    """
    # FIXME: test bert
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # The fp16 mode in the launch CONFIG is a placeholder; patch in the
    # parameterized AMP mode before engine initialization reads it.
    gpc.config.fp16['mode'] = amp_mode
    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

    model = model_builder(checkpoint=False)
    engine, train_dataloader, *_ = colossalai.legacy.initialize(model=model,
                                                                optimizer=optimizer_class(model.parameters(),
                                                                                          lr=1e-3),
                                                                criterion=criterion,
                                                                train_dataloader=train_dataloader)

    try:
        engine.train()
        # One iteration is enough for a smoke test; break after the first batch.
        for data, label in train_dataloader:
            engine.zero_grad()
            data = data.cuda()
            label = label.cuda()
            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)
            else:
                loss = engine(data, label)
            engine.backward(loss)
            engine.step()
            break
    except IndexError:
        # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue
        # the following check fails in apex
        # if cached_x.grad_fn.next_functions[1][0].variable is not x:
        pass


def run_engine(rank, world_size, port):
    """Per-process worker: bring up the distributed environment, then train.

    Args:
        rank (int): rank of this worker process.
        world_size (int): total number of spawned workers.
        port (int): TCP port for the rendezvous host.
    """
    # Initialize the process group for this rank with the shared test CONFIG.
    launch_kwargs = dict(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.legacy.launch(**launch_kwargs)
    run_train()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_engine():
    """Spawn two worker processes and smoke-test the legacy engine on each."""
    world_size = 2
    spawn(run_engine, world_size)


# Allow running this test file directly as a script, outside of pytest.
if __name__ == '__main__':
    test_engine()