# test_norm_gradient_clipping.py
1
2
import pytest
import torch
3
from torch.nn.parameter import Parameter
4
from torch.nn.utils import clip_grad_norm_
5
6
7
8
9
10
11

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
12
13
14
15
16
17
18
19
20
from colossalai.utils.common import clip_grad_norm


def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
    """Return whether ``num`` is within tolerance of the reference ``other``.

    Uses the same criterion as ``numpy.isclose`` / ``torch.allclose``:
    ``|num - other| <= atol + rtol * |other|``.

    Args:
        num: value under test (a Python number or a 0-dim tensor).
        other: reference value (a Python number or a 0-dim tensor).
        rtol: relative tolerance, scaled by the magnitude of ``other``.
        atol: absolute tolerance.

    Returns:
        The result of the comparison. This is a ``bool`` for Python numbers and a
        boolean tensor when tensors are passed.
    """
    # abs(other): without it a negative reference makes the tolerance
    # shrink (or go negative), so even identical values would fail.
    return abs(num - other) <= atol + rtol * abs(other)


def shard_param(p: ColoParameter) -> None:
    """Shard ``p`` along dim 0 across its tensor-parallel group, in place.

    The parameter is redistributed with ``ShardSpec([0], [tp_world_size])``.
    Its gradient is then replaced by this rank's dim-0 chunk, so the local
    gradient lines up with the local shard.

    Args:
        p: parameter to shard. It must already have a ``.grad`` set.
    """
    pg = p.get_process_group()
    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
    # NOTE(review): the grad is sliced by hand here, presumably because
    # _redistribute does not reshard .grad — confirm against ColoParameter.
    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()


def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
    """Assert that ``colo_p.grad`` equals the corresponding part of ``p.grad``.

    If the two parameters have different shapes, ``colo_p`` is treated as a
    dim-0 shard. Only this rank's chunk of the reference gradient is compared.
    """
    group = colo_p.get_process_group()
    expected = p.grad
    if p.shape != colo_p.shape:
        expected = expected.chunk(group.tp_world_size(), 0)[group.tp_local_rank()]
    assert torch.allclose(expected, colo_p.grad), f'diff: {torch.abs(expected - colo_p.grad)}'


@parameterize('dtype', [torch.float])
@parameterize('device', ['mixed', 'cuda', 'cpu'])
@parameterize('norm_type', [2.0, 3.0, float('inf')])
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
    """Compare ColossalAI ``clip_grad_norm`` against PyTorch ``clip_grad_norm_``.

    Four 4x4 parameters are built twice: as plain ``Parameter`` objects and as
    ``ColoParameter`` objects in a tensor-parallel group of size ``world_size``.
    Both copies receive identical random gradients, and ColoParameters 0 and 2
    are then sharded. Both sets are clipped to max norm 1.0. The returned total
    norms must agree, and every clipped gradient must match its reference.

    Args:
        world_size: tensor-parallel degree used for the process group.
        dtype: dtype of the parameters and gradients.
        device: 'cpu' puts all parameters on CPU. 'mixed' puts the first two on
            the current device and the last two on CPU. Any other value puts
            all parameters on the current device.
        norm_type: order of the norm given to both clipping functions.
    """
    print(f'{world_size}, {dtype}, {device}, {norm_type}')
    accel = get_current_device()
    if device == 'cpu':
        placement = [torch.device('cpu')] * 4
    elif device == 'mixed':
        placement = [accel, accel, torch.device('cpu'), torch.device('cpu')]
    else:
        placement = [accel] * 4

    pg = ProcessGroup(tp_degree=world_size)
    params = [Parameter(torch.empty(4, 4, dtype=dtype, device=dev)) for dev in placement]
    colo_params = [
        ColoParameter(torch.empty(4, 4, dtype=dtype, device=dev), spec=ColoTensorSpec(pg))
        for dev in placement
    ]

    # Same random gradient for both copies; the Colo copy owns a detached clone.
    for ref, colo in zip(params, colo_params):
        g = torch.rand_like(ref)
        ref.grad = g
        colo.grad = g.clone().detach()

    for idx in (0, 2):
        shard_param(colo_params[idx])

    max_norm = 1.0
    torch_norm = clip_grad_norm_(params, max_norm, norm_type=norm_type)
    colo_norm = clip_grad_norm(colo_params, max_norm, norm_type=norm_type)
    assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}'

    for ref, colo in zip(params, colo_params):
        check_grad_equal(ref, colo)


def run_dist(rank, world_size, port):
    """Per-process worker: start ColossalAI on NCCL, then run the clipping checks.

    Args:
        rank: rank of this process.
        world_size: total number of processes. It is also the tensor-parallel
            degree used by ``run_grad_clip_norm``.
        port: port on localhost used for the rendezvous.
    """
    disable_existing_loggers()
    launch_kwargs = dict(rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(config={}, **launch_kwargs)
    run_grad_clip_norm(world_size=world_size)


@pytest.mark.skip("this need to be updated")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
    """Spawn ``world_size`` processes that each run ``run_dist``.

    Currently marked skip. The skip reason says the test needs to be updated.

    Args:
        world_size: number of processes to spawn (parametrized as 1 and 2).
    """
    spawn(run_dist, world_size)


if __name__ == '__main__':
    # Run the two-process variant directly, outside pytest.
    _WORLD_SIZE = 2
    test_zero_clip_grad(_WORLD_SIZE)