test_gradient_accumulation.py

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader
from colossalai.testing import rerun_on_exception
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Config
BATCH_SIZE = 2
NUM_CLASSES = 10

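# gradient_accumulation=4 makes the engine accumulate gradients over 4
# micro-steps and apply a single, grad-norm-clipped optimizer update on the
# 4th one; the assertions at the end of this test check exactly that behaviour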
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
              clip_grad_norm=1.0,
              gradient_accumulation=4)
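
# Conceptually, gradient accumulation with N = 4 is equivalent to this plain
# PyTorch sketch (illustrative only, not ColossalAI's actual implementation):
#
#     for i, (img, label) in enumerate(train_dataloader):
#         loss = criterion(model(img), label) / N   # scale each micro-batch loss
#         loss.backward()                           # grads accumulate in .grad
#         if (i + 1) % N == 0:
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#             optimizer.step()                      # one real parameter update
#             optimizer.zero_grad()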


def run_no_pipeline(rank, world_size, port):
    # init dist env
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = resnet18(num_classes=NUM_CLASSES)

    # build dataloaders
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

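    # wrap model, optimizer and criterion into an engine; with
    # gradient_accumulation set in CONFIG, engine.step() defers the real
    # optimizer update until enough micro-batch gradients have accumulated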
    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()
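
    # track the first parameter and its gradient after every micro-step; the
    # parameters are leaf tensors, so retain_grad() below is only a safeguard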
    param_track = []
    grad_track = []
    next(model.parameters()).retain_grad()

    engine.train()
    step = 0
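    # run exactly CONFIG['gradient_accumulation'] micro-steps; engine.step()
    # should only perform a real parameter update on the last one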
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()

        # record the first parameter and its gradient after this micro-step
        param_track.append(next(model.parameters())[0].clone())
        grad_track.append(next(model.parameters()).grad[0].clone())
        step += 1
        if step == CONFIG['gradient_accumulation']:
            break

    # gradients keep accumulating across micro-steps, so they differ between
    # iterations, while the parameters stay untouched until the final
    # micro-step applies the real update
    assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations'
    assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
        'param should stay the same while gradients accumulate and only change at the last step'

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_engine():
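    # spawn one process per GPU (nccl backend); this test expects 4 devices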
    world_size = 4
    func = partial(run_no_pipeline, world_size=world_size, port=free_port())
    mp.spawn(func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()