test_gradient_accumluation.py 3.19 KB
Newer Older
Frank Lee's avatar
Frank Lee committed
1
import os
2
3
4
5
from functools import partial
from pathlib import Path

import colossalai
Frank Lee's avatar
Frank Lee committed
6
7
8
9
10
11
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
12
13
14
from colossalai.utils import free_port, get_dataloader
from torch.optim import Adam
from torchvision import transforms
Frank Lee's avatar
Frank Lee committed
15
from torchvision.datasets import CIFAR10
16
from torchvision.models import resnet18
Frank Lee's avatar
Frank Lee committed
17
18

# Config
BATCH_SIZE = 2
NUM_CLASSES = 10

# Single-stage run (pipeline size 1, no tensor parallelism) that accumulates
# gradients over 4 micro-batches and clips the gradient norm at 1.0.
CONFIG = {
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'clip_grad_norm': 1.0,
    'gradient_accumulation': 4,
}
Frank Lee's avatar
Frank Lee committed
25
26


27
def run_no_pipeline(rank, world_size, port):
    """Gradient-accumulation sanity check on a single engine (no pipeline).

    Spawned once per process by ``mp.spawn``. Trains ``resnet18`` on CIFAR10
    for ``CONFIG['gradient_accumulation']`` micro-steps and asserts that
    gradients change every step while parameters are only updated on the
    final accumulation step.

    Args:
        rank (int): rank of the current process (supplied by ``mp.spawn``).
        world_size (int): total number of spawned processes.
        port (int): free TCP port used for distributed initialization.
    """
    # init dist env
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model — use the shared NUM_CLASSES constant (was hard-coded 10)
    model = resnet18(num_classes=NUM_CLASSES)

    # build dataloaders; CIFAR10 root comes from the DATA env var
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # colossalai.initialize wraps model/optimizer in an engine that applies
    # gradient accumulation according to CONFIG
    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)

    # track the first parameter tensor (and its grad) across micro-steps
    param_track = []
    grad_track = []
    next(model.parameters()).retain_grad()

    engine.train()
    for step, (img, label) in enumerate(train_dataloader, start=1):
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()

        # snapshot the first row of the first parameter and its gradient
        param_track.append(next(model.parameters())[0].clone())
        grad_track.append(next(model.parameters()).grad[0].clone())
        if step == CONFIG['gradient_accumulation']:
            break

    # gradients accumulate on every micro-step, so they must differ
    assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations'
    # parameters must stay frozen until the final accumulation step
    assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
        'param should be the same in the first few iterations and only changed in the last iteration'

    gpc.destroy()
    torch.cuda.empty_cache()
Frank Lee's avatar
Frank Lee committed
86
87
88
89


@pytest.mark.dist
def test_engine():
    """Spawn one worker per rank and run the no-pipeline gradient accumulation check."""
    nprocs = 4
    worker = partial(run_no_pipeline, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)
Frank Lee's avatar
Frank Lee committed
93
94
95
96


# Allow running this test directly (outside pytest) for debugging.
if __name__ == '__main__':
    test_engine()