# test_gradient_accumulation.py
import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Hyper-parameters for the gradient-accumulation test.
BATCH_SIZE = 16     # samples per micro-batch
IMG_SIZE = 224      # CIFAR10 images are resized to this resolution
NUM_CLASSES = 10    # CIFAR10 label count

# Colossal-AI launch configuration: a single pipeline stage and a single
# tensor-parallel group, gradient clipping at 1.0, and four accumulation
# steps per optimizer update.
CONFIG = {
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'clip_grad_norm': 1.0,
    'gradient_accumulation': 4,
}


def run_no_pipeline(rank, world_size, port):
    """Worker entry point: verify gradient accumulation on a ResNet-18.

    With ``gradient_accumulation=4`` in CONFIG, the engine should accumulate
    gradients across 4 consecutive iterations and apply the optimizer update
    only on the last one, so parameters stay unchanged in between while
    gradients keep changing.

    Args:
        rank (int): rank of the current process (supplied by ``mp.spawn``).
        world_size (int): total number of spawned processes.
        port (int): free TCP port used for distributed initialization.
    """
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl'
    )

    # build model — keep the head size in sync with the CIFAR10 label count
    model = resnet18(num_classes=NUM_CLASSES)

    # build dataloaders (dataset root comes from the DATA environment variable)
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()
    param_track = []
    grad_track = []
    # keep the gradient of the first parameter alive so it can be inspected
    # after every engine step
    next(model.parameters()).retain_grad()

    engine.train()
    step = 0
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()

        # record the first parameter slice and its gradient after each step
        param_track.append(next(model.parameters())[0].clone())
        grad_track.append(next(model.parameters()).grad[0].clone())
        step += 1
        if step == CONFIG['gradient_accumulation']:
            break

    # gradients accumulate every iteration, so they must differ between steps
    assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations'
    # parameters must change only once the accumulation window completes
    assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
        'param should be the same in the first few iterations and only changed in the last iteration'

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_engine():
    """Spawn four workers and run the gradient-accumulation check on each."""
    nprocs = 4
    worker = partial(run_no_pipeline, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)


if __name__ == '__main__':
    # Allow running this distributed test directly, outside of pytest.
    test_engine()