#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)


def run_dist(rank, world_size, port, parallel_config):
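    # initialize the process group and the colossalai context for this rank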
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

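    # run the same short training loop for several registered test models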
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
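        # build the model under ZeroInitContext so its parameters are sharded as they are created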
        with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
                             target_device=torch.cuda.current_device(),
                             shard_strategy=gpc.config.zero.model_config.shard_strategy,
                             shard_param=True):
            colo_model = model_builder(checkpoint=True)

        colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                               optimizer=colo_optimizer,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)
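        # build an unsharded fp32 torch copy of the model, seeded with the engine's parameters, as the reference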
        torch_model = model_builder(checkpoint=True).half()
        col_model_deepcopy(engine.model, torch_model)
        torch_model = torch_model.cuda().float()

        engine.train()
        torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)

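        # mirror the data-parallel setup with DDP so reference gradients are averaged across ranks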
        if dist.get_world_size() > 1:
            torch_model = DDP(torch_model)

        i = 0
        for data, label in train_dataloader:
            if i > 4:
                break

            data, label = data.cuda(), label.cuda()

            engine.zero_grad()
            torch_optimizer.zero_grad()

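            # some registered models expose a separate criterion; others compute the loss inside forward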
            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)

                torch_output = torch_model(data)
                torch_loss = engine.criterion(torch_output, label)
            else:
                loss = engine(data, label)
                torch_loss = torch_model(data, label)

            engine.backward(loss)
            engine.step()

            torch_loss.backward()

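            # sanity-check that the reference model's gradients stay finite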
            for param in torch_model.parameters():
                if param.grad is not None:
                    assert not has_inf_or_nan(param.grad)

            torch_optimizer.step()
            i += 1

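        # after training, the ZeRO model's parameters should match the torch reference within a loose tolerance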
        if parallel_config == MP_PARALLEL_CONFIG:
            check_params(torch_model, colo_model, loose=True)
        elif parallel_config == ZERO_PARALLEL_CONFIG:
            check_sharded_model_params(torch_model, colo_model, loose=True)


# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
def test_mp_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_engine(world_size=4)