# test_fwd_bwd.py
from functools import partial
from itertools import islice

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed


def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
    """Assert that the ZeroDDP parameters match the reference model's gradients.

    With Gemini's chunk reuse optimization, the parameter storage holds the
    gradients after a backward pass. That is why each ZeroDDP parameter is
    compared against ``.grad`` of the corresponding ``torch_model`` parameter
    rather than against its value.

    Args:
        model: the ZeroDDP-wrapped model, after a backward pass.
        torch_model: the reference model, paired with ``model`` by parameter order.

    Raises:
        AssertionError: if any pair differs beyond ``atol=1e-3, rtol=1e-5``.
            The message carries the maximum absolute difference.
    """
    chunk_manager = model.chunk_manager
    param_list = list(model.parameters())
    # Call access_chunk on every chunk backing these parameters before reading them.
    for chunk in chunk_manager.get_chunks(param_list):
        chunk_manager.access_chunk(chunk)

    for p0, p1 in zip(param_list, torch_model.parameters()):
        assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'resnet18'])
@parameterize('use_grad_checkpoint', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
    """Compare one forward/backward step of a Gemini ZeroDDP model against a DDP baseline.

    The registered test component ``model_name`` is built twice:

    * once under ``ColoInitContext`` and wrapped in ``ZeroDDP``, driven by a
      ``GeminiManager`` / ``ChunkManager``;
    * once as a plain ``.cuda()`` module, converted with apex AMP (O2) and
      wrapped in ``DDP``.

    The baseline weights are copied from the ColoInit model, so both models
    start identical. Their losses on the first training batch must agree
    within ``rtol=1e-2``.

    Args:
        placement_policy: Gemini placement policy passed to ``GeminiManager``.
        keep_gather: value stored in the chunk config's ``'keep_gathered'`` entry.
        model_name: key of the component in ``non_distributed_component_funcs``.
        use_grad_checkpoint: forwarded to the component's model builder.
    """
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # Only the model builder, the training data and the loss are used here.
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder(use_grad_checkpoint)

    # Reference model: same architecture, weights copied from the ColoInit model.
    torch_model = model_builder(use_grad_checkpoint).cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    # Override the searched configuration with a fixed chunk size and the
    # parameterized keep_gathered flag for the current world size.
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    model.eval()
    torch_model.eval()

    set_seed(pg.dp_local_rank())
    # Only a single fwd + bwd can be tested: after backward, Gemini's parameters
    # hold gradients because of the chunk reuse optimization.
    for input_ids, label in islice(train_dataloader, 1):
        torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False)
        loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True)

        assert torch.allclose(loss, torch_loss, rtol=1e-2), "{} {} {}".format(
            torch.max(torch.abs(loss - torch_loss)).item(), loss, torch_loss)

        # FIXME(1SAA) bert and resnet18 can not pass the check_grad
        # check_grad(model, torch_model)


def run_dist(rank, world_size, port):
    """Per-process entry point for ``mp.spawn``.

    Calls ``colossalai.launch`` with an empty config on ``localhost:port``
    using the NCCL backend, then runs every parameterized combination of
    ``exam_gpt_fwd_bwd``.
    """
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    exam_gpt_fwd_bwd()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn ``world_size`` processes that all run ``run_dist`` against one free port."""
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Run the single-process configuration when the file is executed directly.
    test_gpt(1)