test_engine.py 2.14 KB
Newer Older
1
import pytest
2
import torch
3
4

import colossalai
5
6
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.core import global_context as gpc
7
8
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
9

10
11
12
CONFIG = dict(
    parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0
)
13
14


15
16
@parameterize("model_name", ["repeated_computed_layers", "resnet18", "repeated_computed_layers"])
@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
17
def run_train(model_name, amp_mode):
ver217's avatar
ver217 committed
18
    # FIXME: test bert
19
20
21
    model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
    train_dataloader = DummyDataloader(data_gen_fn)
    criterion = lambda x: x.sum()
22
    gpc.config.fp16["mode"] = amp_mode
23

24
    model = model_builder()
25
26
    engine, train_dataloader, *args = colossalai.legacy.initialize(
        model=model,
27
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
28
29
30
        criterion=criterion,
        train_dataloader=train_dataloader,
    )
31
32
33

    try:
        engine.train()
34
        for data in train_dataloader:
35
            engine.zero_grad()
36
            data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
37
            if criterion:
38
39
                output = engine(**data)
                loss = engine.criterion(output)
40
            else:
41
                loss = engine(**data)
42
43
44
45
46
47
48
49
            engine.backward(loss)
            engine.step()
            break
    except IndexError:
        # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue
        # the following check fails in apex
        # if cached_x.grad_fn.next_functions[1][0].variable is not x:
        pass
50
51
52
53


def run_engine(rank, world_size, port):
    # init dist env
54
55
56
    colossalai.legacy.launch(
        config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl"
    )
57
    run_train()
58
59
60


@pytest.mark.dist
61
@rerun_if_address_is_in_use()
62
def test_engine():
63
    spawn(run_engine, 2)
64
65


66
if __name__ == "__main__":
67
    test_engine()