#!/usr/bin/env python
# -*- encoding: utf-8 -*-
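"""Integration test for the ColossalAI engine under the ZeRO and model-parallel (MP) configs.

For a handful of registered test models, the engine-trained model is compared
against a plain PyTorch (optionally DDP-wrapped) reference model trained on the
same data for a few iterations.
"""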

import copy
from functools import partial

import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.utils import free_port
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan

from tests.components_to_test.registry import non_distributed_component_funcs
from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG, MP_PARALLEL_CONFIG, check_params


def run_dist(rank, world_size, port, parallel_config):
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

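    # run the comparison for a few registered component models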
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

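        # build the ColossalAI model and an identical torch copy used as the reference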
        colo_model = model_builder(checkpoint=True)
        torch_model = copy.deepcopy(colo_model).cuda()
        torch_model.train()
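        # wrap model, optimizer, criterion and dataloader into a ColossalAI engine according to parallel_config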
        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                               optimizer=optimizer_class,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)
        engine.train()
        torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)

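        # with more than one rank, keep the reference model's gradients in sync via DDP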
        if dist.get_world_size() > 1:
            torch_model = DDP(torch_model)

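        # train both models on the same batches for a few iterations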
        i = 0
        for data, label in train_dataloader:
            if i > 4:
                break

            data, label = data.cuda(), label.cuda()

            engine.zero_grad()
            torch_optimizer.zero_grad()

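            # some test models take (data, label) and return the loss directly; others need an explicit criterion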
            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)

                torch_output = torch_model(data)
                torch_loss = engine.criterion(torch_output, label)
            else:
                loss = engine(data, label)
                torch_loss = torch_model(data, label)

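            # backward and optimizer step go through the engine rather than the model/optimizer directly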
            engine.backward(loss)
            engine.step()

            torch_loss.backward()

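            # sanity check: the reference model's gradients must stay finite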
            for param in torch_model.parameters():
                if param.grad is not None:
                    assert not has_inf_or_nan(param.grad)

            torch_optimizer.step()
            i += 1

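        # compare trained parameters: plain comparison for model parallel, padded sharded comparison for ZeRO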
        if parallel_config == MP_PARALLEL_CONFIG:
            check_params(torch_model, colo_model, loose=True)
        elif parallel_config == ZERO_PARALLEL_CONFIG:
            check_sharded_params_padding(torch_model, colo_model, loose=True)


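# each test spawns one process per rank; every process runs run_dist with the chosen parallel config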
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
def test_mp_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_engine(world_size=4)