#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import checkpoint, clip_grad_norm_fp32
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2


def checkpoint_wrapper(module, enable=True):
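    # Monkey-patch the module's forward with colossalai's activation
    # checkpointing (the trailing False is the activation_offload flag).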
    if enable:
        module.forward = partial(checkpoint, module.forward, False)
    return module


class Net(nn.Module):
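    """A small MLP whose fc1 and fc2 layers are each applied twice per
    forward pass, so their parameters accumulate gradients from both uses."""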

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
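    # One training step: forward (optionally under fp16 autocast), backward
    # on an fp32 loss, gradient clipping, then the optimizer update.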
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
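    # leave the autocast region and run backward on an fp32 loss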
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()


def clip_grad(model, norm_type):
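    # The DDP baseline holds fp32 gradients, so torch's clip_grad_norm_
    # applies; the sharded model goes through colossalai's
    # clip_grad_norm_fp32 instead.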
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
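    # loose tolerances accommodate fp16 rounding in the sharded model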
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
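    # Each rank owns one shard: split the unsharded gradient into 4
    # (= world_size) chunks and compare this rank's chunk against the
    # shard's gradient, dropping any trailing padding added by sharding.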
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
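    # Same scheme as check_grads, applied to the parameter values.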
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def run_dist(rank, world_size, port):
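    # per-worker entry point: set up the NCCL process group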
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
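    # world_size of 4 matches the 4-way chunking in check_grads/check_params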
    world_size = 4
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_clip_grad()