# test_gradient_accumulation.py
import os
2
3
4
5
from functools import partial
from pathlib import Path

import colossalai
6
from colossalai.testing.utils import rerun_if_address_is_in_use
Frank Lee's avatar
Frank Lee committed
7
8
9
10
11
12
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
13
from colossalai.utils import free_port, get_dataloader
14
from colossalai.testing import rerun_if_address_is_in_use
15
16
from torch.optim import Adam
from torchvision import transforms
Frank Lee's avatar
Frank Lee committed
17
from torchvision.datasets import CIFAR10
18
from torchvision.models import resnet18
Frank Lee's avatar
Frank Lee committed
19
20

# ----- Test configuration -----
# Each step feeds BATCH_SIZE samples; the engine accumulates gradients for
# `gradient_accumulation` steps before applying a single optimizer update.
BATCH_SIZE = 2
NUM_CLASSES = 10

CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
              clip_grad_norm=1.0,
              gradient_accumulation=4)
def run_no_pipeline(rank, world_size, port):
    """Worker process: train a few steps and verify gradient accumulation.

    Runs ``CONFIG['gradient_accumulation']`` training steps on CIFAR10 and
    checks that gradients change every iteration while parameters only change
    on the final (accumulation-boundary) step.

    Args:
        rank (int): rank of this process (supplied by ``mp.spawn``).
        world_size (int): total number of spawned processes.
        port (int): free TCP port used to initialise the NCCL process group.
    """
    # init dist env
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model — use the shared NUM_CLASSES constant instead of a
    # hard-coded 10 so the model stays in sync with the config above
    model = resnet18(num_classes=NUM_CLASSES)

    # build dataloaders (CIFAR10 root comes from the DATA env var)
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # because CONFIG sets `gradient_accumulation`, initialize() wraps the
    # engine so that optimizer.step() is a no-op until the boundary step
    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)

    # snapshot the first parameter tensor (and its grad) each step
    param_track = []
    grad_track = []
    next(model.parameters()).retain_grad()

    engine.train()
    step = 0
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()

        # check
        param_track.append(next(model.parameters())[0].clone())
        grad_track.append(next(model.parameters()).grad[0].clone())
        step += 1
        if step == CONFIG['gradient_accumulation']:
            break

    # gradients accumulate, so they differ every iteration; parameters must
    # stay frozen until the final step applies the real optimizer update
    assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations'
    assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
        'param should be the same in the first few iterations and only changed in the last iteration'

    gpc.destroy()
    torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_engine():
    """Spawn one worker per rank and run the gradient-accumulation check."""
    world_size = 4
    worker = partial(run_no_pipeline, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)
# allow running this test directly without invoking pytest
if __name__ == '__main__':
    test_engine()