# unittest_cpu_adam.py: checks the cpu_adam extension against a plain-PyTorch Adam reference.
import math
import torch

from colossalai.testing import parameterize


def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    loss_scale,
    use_adamw,
):
    """Apply one Adam / AdamW step using plain PyTorch tensor ops.

    Serves as the reference implementation the ``cpu_adam`` extension is
    compared against.

    Args:
        step: step index used in the bias corrections ``1 - beta**step``.
        lr: learning rate.
        beta1: decay rate of the first-moment running average.
        beta2: decay rate of the second-moment running average.
        eps: term added to the bias-corrected denominator.
        weight_decay: decay coefficient; ``0`` disables weight decay.
        param: parameter tensor, updated in place.
        grad: gradient tensor. When ``loss_scale > 0`` it is divided by
            ``loss_scale`` in place (the caller sees the unscaled values).
            Otherwise it is not modified, because L2 decay uses an
            out-of-place add.
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        loss_scale: gradient loss scale; values ``<= 0`` mean "not scaled".
        use_adamw: when ``weight_decay != 0``, choose decoupled (AdamW) decay
            of ``param`` instead of adding ``weight_decay * param`` to the
            gradient.

    Returns:
        None. All results are written into the tensors passed in.
    """
    # Unscale on the caller's own tensor so it can be compared afterwards.
    if loss_scale > 0:
        grad.div_(loss_scale)

    corr1 = 1 - beta1**step
    corr2 = 1 - beta2**step

    if weight_decay != 0 and use_adamw:
        # Decoupled weight decay: shrink the parameter directly.
        param.mul_(1 - lr * weight_decay)
    elif weight_decay != 0:
        # Classic L2 regularisation: fold the decay into a new gradient tensor.
        grad = grad.add(param, alpha=weight_decay)

    # Running averages of the gradient and of its element-wise square.
    exp_avg.mul_(beta1)
    exp_avg.add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2)
    exp_avg_sq.addcmul_(grad, grad, value=1 - beta2)

    denominator = exp_avg_sq.sqrt().div(math.sqrt(corr2)).add_(eps)
    param.addcdiv_(exp_avg, denominator, value=-(lr / corr1))


def assertLess(data_diff, threshold, msg):
    """Assert that ``data_diff`` is strictly below ``threshold``; report ``msg`` otherwise."""
    is_below = data_diff < threshold
    assert is_below, msg


def assertTrue(condition, msg):
    """Assert that ``condition`` is truthy; report ``msg`` otherwise."""
    holds = bool(condition)
    assert holds, msg


@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('loss_scale', [-1, 2 ** 5])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
    """Compare the ``cpu_adam`` extension with ``torch_adam_update``.

    For each parameter combination supplied by the ``parameterize``
    decorators, 1024 random 64-element problems are run through both
    implementations. The updated parameters, the unscaled gradients and both
    moment buffers must then agree to within ``threshold``.

    Raises:
        ImportError: if the ``cpu_adam`` extension cannot be imported.
        AssertionError: if any compared tensor differs by ``threshold`` or more.
    """
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    # NOTE(review): with weight_decay == 0 the reference ignores ``adamw``,
    # so the AdamW vs. L2 decay paths are not exercised here. Consider
    # parameterizing a non-zero value.
    weight_decay = 0
    threshold = 1e-3

    # Import once, outside the loop, and keep the original failure as the cause.
    try:
        import cpu_adam
    except ImportError as err:
        raise ImportError(
            "could not import the cpu_adam extension; build/install it before running this test"
        ) from err

    for _ in range(1024):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(64, dtype=g_dtype)
        if loss_scale > 0:
            p_grad.mul_(loss_scale)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(p_data.shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(p_data.shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        # Re-create optimizer 0 for every problem instance.
        cpu_adam.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
        # The comparison below assumes the extension updates p_data,
        # exp_avg and exp_avg_sq in place.
        cpu_adam.adam_update(
            0,
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,  # NOTE(review): presumably the bias-correction flag; confirm against the extension
            p_data.view(-1),  # p_dtype data (fp32 or fp16)
            p_grad.view(-1),  # g_dtype grad (fp32 or fp16), loss-scaled when loss_scale > 0
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            loss_scale,
        )

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_data_copy,  # fp32 data
            p_grad_copy,  # fp32 grad; unscaled in place by the reference
            exp_avg_copy,
            exp_avg_sq_copy,
            loss_scale,
            adamw,
        )

        # The reference divided p_grad_copy by loss_scale in place. Unscale
        # p_grad the same way so the gradients are comparable. (This assumes
        # the extension leaves p_grad scaled.)
        if loss_scale > 0:
            p_grad.div_(loss_scale)

        data_diff = torch.max(torch.abs(p_data_copy - p_data))
        assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")