#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy

import colossalai
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
from colossalai.testing import parameterize, rerun_on_exception


def checkpoint_wrapper(module, enable=True):
    """Optionally enable activation checkpointing on ``module``.

    When ``enable`` is True, the module's ``forward`` is replaced in place by
    a partial application of ``checkpoint`` over the original forward (with
    the extra flag fixed to ``False``). When ``enable`` is False the module
    is returned untouched.

    Returns the (possibly modified) module for chaining.
    """
    if not enable:
        return module
    original_forward = module.forward
    module.forward = partial(checkpoint, original_forward, False)
    return module


class Net(nn.Module):
    """Small MLP used by the gradient-clipping test.

    The layer sequence deliberately reuses ``fc1`` and ``fc2`` twice, so their
    parameters are shared across two positions in the forward pass. ``fc1``
    may optionally be wrapped with activation checkpointing.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        # Note the repeats: fc1/fc2 each appear twice (weight sharing).
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        out = x
        for module in self.layers:
            out = module(out)
        return out


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    """Run one optimization step: forward, backward, grad clip, update.

    The forward pass and loss (a plain sum of the output) run under CUDA AMP
    autocast when ``enable_autocast`` is set; the loss is cast to fp32 before
    backpropagation. Gradients are clipped via :func:`clip_grad` using
    ``norm_type`` before ``optimizer.step()``.
    """
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        output = model(x)
        loss = output.sum()
    # Cast the (possibly reduced-precision) loss to fp32, then backprop.
    loss.float().backward()
    clip_grad(model, norm_type)
    optimizer.step()


def clip_grad(model, norm_type):
    """Clip all model gradients to a max norm of 1.0.

    Dispatches on the wrapper type: plain torch clipping for DDP models,
    colossalai's fp32-aware clipping for everything else (e.g. sharded
    ZeRO models).
    """
    params = model.parameters()
    if isinstance(model, DDP):
        clip_grad_norm_(params, max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(params, max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    """Element-wise closeness check with optionally relaxed tolerances.

    With ``loose=True`` both absolute and relative tolerance are widened to
    1e-3; otherwise torch's defaults apply.
    """
    tolerances = {'atol': 1e-3, 'rtol': 1e-3} if loose else {}
    return torch.allclose(tensor_a, tensor_b, **tolerances)


def check_grads(model, zero_model, loose=False):
    """Assert this rank's shard of every dense grad matches the ZeRO grad.

    Each full-model gradient is flattened and split into 4 chunks (one per
    rank in the 4-process test); ranks beyond the chunk count skip that
    parameter. Padding added by the shard strategy (``zero_shard_padding``)
    is stripped from the ZeRO side before comparison.
    """
    rank = dist.get_rank()
    for param, shard in zip(model.parameters(), zero_model.parameters()):
        shard_grad = shard.grad.clone().to(param.device)
        pieces = torch.flatten(param.grad).chunk(4)
        if rank >= len(pieces):
            continue
        local_grad = pieces[rank]
        if shard.zero_shard_padding > 0:
            shard_grad = shard_grad[:-shard.zero_shard_padding]
        assert local_grad.dtype == shard_grad.dtype
        assert allclose(local_grad, shard_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    """Assert this rank's shard of every dense param matches the ZeRO param.

    Mirrors :func:`check_grads` but compares parameter values: flatten the
    full parameter, take this rank's chunk (of 4), strip shard padding from
    the ZeRO side, then compare dtype and values.
    """
    rank = dist.get_rank()
    for param, shard in zip(model.parameters(), zero_model.parameters()):
        padding = shard.zero_shard_padding
        shard_data = shard.clone().to(param.device)
        pieces = torch.flatten(param).chunk(4)
        if rank >= len(pieces):
            continue
        local_data = pieces[rank]
        if padding > 0:
            shard_data = shard_data[:-padding]
        assert local_data.dtype == shard_data.dtype
        assert allclose(local_data, shard_data, loose=loose)


def run_dist(rank, world_size, port):
    """Per-worker entry point: initialize the distributed environment.

    Silences pre-existing loggers, then launches colossalai with an empty
    config over NCCL on localhost.

    NOTE(review): as visible here the body only performs initialization —
    confirm against upstream whether the actual clipping checks belong here.
    """
    disable_existing_loggers()
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
Jiarui Fang's avatar
Jiarui Fang committed
102
103


104
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_clip_grad():
    """Spawn 4 workers and run the distributed gradient-clipping check.

    Retries automatically when the chosen rendezvous port is already taken.
    """
    world_size = 4
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Allow running this test file directly, outside of pytest.
    test_zero_clip_grad()