#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F

from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, reset_seeds, seed, set_mode
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils.activation_checkpoint import checkpoint


def forward(x, weight):
    out = torch.matmul(x, weight)
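    # dropout runs under the DATA parallel seed, so the recomputation done by
    # activation checkpointing must restore this RNG state to get the same mask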
    with seed(ParallelMode.DATA):
        out_ = F.dropout(out, p=0.4, training=True)
    return out_


def forward_inplace_ckpt(x, weight, cpu_offload=False):
    out = torch.matmul(x, weight)
    bn = torch.nn.BatchNorm1d(4, affine=False)
    bn = bn.to(device="cuda")
    out = bn(out)

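    # the checkpointed function applies ReLU in place; in-place ops inside a
    # checkpointed region are only supported with use_reentrant=False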
    def ckpt0(x):
        return F.relu(x, inplace=True)

    out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
    return out


def forward_inplace(x, weight):
    out = torch.matmul(x, weight)
    bn = torch.nn.BatchNorm1d(4, affine=False)
    bn = bn.to(device="cuda")
    out = bn(out)
    out = F.relu(out, inplace=True)
    return out


@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):

    # the seed manager is a singleton, so reset the seeds here;
    # otherwise other tests could affect this test
    reset_seeds()

    # initialization is done here so that it does not change the CUDA RNG state used below
    inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
    weight = torch.rand(2, 4, requires_grad=True, device='cuda')

    # Get a copy of input tensors
    inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda')
    inputs_.data.copy_(inputs.data)
    weight_ = torch.empty(2, 4, requires_grad=True, device='cuda')
    weight_.data.copy_(weight.data)

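    # register distinct CUDA RNG seeds for the GLOBAL and DATA parallel modes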
    add_seed(ParallelMode.GLOBAL, 1024)
    add_seed(ParallelMode.DATA, 1026)
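    # snapshot the CUDA RNG state of each mode so the checkpointed run below
    # can start from exactly the same randomness as the plain run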
    set_mode(ParallelMode.GLOBAL)
    global_cuda_rng_state = torch.cuda.get_rng_state()
    set_mode(ParallelMode.DATA)
    data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
    set_mode(ParallelMode.GLOBAL)

    out = forward(inputs, weight)
    loss = out.sum()
    loss.backward()

    # Recover CUDA RNG states
    set_mode(ParallelMode.GLOBAL)
    torch.cuda.set_rng_state(global_cuda_rng_state)
    set_mode(ParallelMode.DATA)
    torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
    set_mode(ParallelMode.GLOBAL)

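    # same forward pass, but under activation checkpointing; cpu_offload
    # additionally moves the saved activations to host memory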
    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
    loss = out.sum()
    loss.backward()

    assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
    torch.cuda.empty_cache()

    # Extra test for use_reentrant=False: in-place ops inside a checkpointed
    # function are only supported by the non-reentrant implementation
    if not use_reentrant:
        # Recover CUDA RNG states
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        out = forward_inplace(inputs, weight)
        loss = out.sum()
        loss.backward()

        # Recover CUDA RNG states
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
        loss = out.sum()
        loss.backward()

        assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
        torch.cuda.empty_cache()

    # as the seed manager is a singleton, reset the seeds here;
    # other tests would fail if run together with this test,
    # since they cannot overwrite the seeds set by this test
    reset_seeds()


if __name__ == "__main__":
    test_activation_checkpointing(False, False)