#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2


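# Optionally wrap a module's forward pass with ColossalAI's activation checkpointing,
# so activations are recomputed during backward instead of being stored.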
def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward, False)
    return module


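# Small MLP used as the test model; fc1 and fc2 are each applied twice,
# so gradient accumulation through reused (shared) parameters is also exercised.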
class Net(nn.Module):

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


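# One training step: forward (optionally under AMP autocast), backward,
# gradient clipping, then the optimizer update.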
def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()


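# Clip gradients to a max norm of 1.0: the DDP baseline uses torch's clip_grad_norm_,
# while the ZeRO sharded model goes through ColossalAI's clip_grad_norm_fp32.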
def clip_grad(model, norm_type):
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


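# Tensor equality check, with an optional looser tolerance for reduced-precision runs.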
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


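# Each rank compares its shard of the full DDP gradient (split into world_size=4 chunks)
# against the ZeRO-sharded gradient, after stripping any shard padding.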
def check_grads(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


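# Shard-wise comparison of parameter values, mirroring check_grads.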
def check_params(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


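# Per-process entry point: initialize the ColossalAI distributed environment
# (NCCL backend) on the given port.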
def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


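# Spawn a 4-process distributed run on a free port; the decorator retries
# the test if the chosen address happens to be in use.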
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_clip_grad()