from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.util import set_seed, tensor_equal, tensor_shard_equal
from vit import get_training_components

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


# For Linear layers this is a 1D row split, because the Linear weight is transposed during the forward computation.
# For the other layers it is a 1D column split.
# LayerNorm is not supported for now.
# embeddings.patch_embeddings.projection is an nn.Conv2d, so it is excluded as well:
# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
def init_1d_row_for_linear_weight_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


# Similarly, this is a 1D column split for Linear layers but a row split for the other layers.
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if ('weight' in n
                    or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


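# Every (possibly sharded) parameter should match the corresponding slice of the reference model's parameter.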
def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p, p)


def check_grad_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if torch_p.grad.shape == p.grad.shape:
            assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
        else:
            # the parameter is sharded along exactly one dimension, so compare the local
            # gradient against the matching chunk of the full reference gradient
            dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0)


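# Build a ColoTensor model and an identical reference torch model, apply the given tensor-parallel
# spec to the former, then run one training step and compare logits, gradients and updated parameters.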
def run_vit(init_spec_func, use_ddp):
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()
    if use_ddp:
        model = ColoDDP(model)
        torch_model = DDP(torch_model,
                          device_ids=[gpc.get_global_rank()],
                          process_group=gpc.get_group(ParallelMode.DATA))
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

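    # shard the ColoTensor model according to the given spec; the reference torch model keeps full parameters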
    world_size = torch.distributed.get_world_size()
    init_spec_func(model, world_size)

    check_param_equal(model, torch_model)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))

    optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

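    # a single step is enough for the comparison; the loop breaks after the first batch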
    for i, image_dict in enumerate(train_dataloader):
        if use_ddp:
            model.zero_grad()
        else:
            optimizer.zero_grad()
        logits = model(image_dict['pixel_values'])
        torch_logits = torch_model(image_dict['pixel_values'])
        assert tensor_equal(torch_logits.logits, logits.logits)
        loss = criterion(logits.logits, image_dict['label'])
        torch_loss = criterion(torch_logits.logits, image_dict['label'])
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model)
        optimizer.step()
        torch_optimizer.step()
        check_param_equal(model, torch_model)
        break


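# With DDP enabled, half of the ranks are used for tensor parallelism and the rest for data parallelism;
# otherwise all ranks form a single 1D tensor-parallel group.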
def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
    run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)


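# Spawn one process per rank; the test is rerun if the chosen port is already in use.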
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_vit(1, False)