# test_fused_adam.py
# Unit test comparing colossalai's FusedAdam against torch's Adam/AdamW.
import torch
import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW

from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize


class FC(nn.Module):
    """Minimal test model: one 64 -> 64 linear layer wrapped in a Sequential.

    The layer lives under the attribute ``fc`` so its state_dict keys
    (``fc.0.weight`` / ``fc.0.bias``) stay stable for checkpoint copying.
    """

    def __init__(self) -> None:
        super().__init__()
        # Sequential wrapper kept so parameter names match existing tests.
        self.fc = nn.Sequential(nn.Linear(64, 64))

    def forward(self, x):
        """Run the linear stack on ``x`` and return the result."""
        out = self.fc(x)
        return out


@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    """Check FusedAdam against the reference torch optimizer.

    Two identically-initialized copies of ``FC`` take one optimizer step per
    sample over the same data: one driven by ``FusedAdam`` (optionally with
    gradients cast to ``g_dtype``), the other by ``torch.optim.Adam`` /
    ``AdamW``. The resulting parameters must agree within 2e-3.
    """
    model = FC().cuda().to(p_dtype)
    model_ref = FC().cuda().to(p_dtype)
    # Give the reference model the exact same initial weights.
    model_ref.load_state_dict(model.state_dict().copy())

    if adamw:
        fused_optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
        ref_optim = AdamW(model_ref.parameters(), lr=1e-3)
    else:
        fused_optim = FusedAdam(model.parameters(), lr=1e-3)
        ref_optim = Adam(model_ref.parameters(), lr=1e-3)

    inputs = torch.rand(1024, 64).cuda().to(p_dtype)
    inputs_ref = inputs.clone()
    targets = torch.rand(1024, 64).cuda().to(p_dtype)

    # Fused run: one step per sample.
    for sample, target in zip(inputs, targets):
        pred = model(sample)
        loss = ((target - pred) ** 2).sum()
        fused_optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
            # Exercise the mixed-dtype path: hand the fused kernel
            # gradients in g_dtype while params stay in p_dtype.
            for param in fused_optim.param_groups[0]['params']:
                param.grad.data = param.grad.data.to(g_dtype)
        fused_optim.step()

    # Reference run: identical schedule with the torch optimizer.
    for sample, target in zip(inputs_ref, targets):
        pred = model_ref(sample)
        loss = ((target - pred) ** 2).sum()
        ref_optim.zero_grad()
        loss.backward()
        ref_optim.step()

    fused_params = fused_optim.param_groups[0]['params']
    ref_params = ref_optim.param_groups[0]['params']
    assert len(fused_params) == len(ref_params)

    for fused_p, ref_p in zip(fused_params, ref_params):
        # Half-precision runs can overflow to NaN; skip those params
        # rather than fail on a meaningless comparison.
        if torch.isnan(fused_p).any() or torch.isnan(ref_p).any():
            continue
        assert torch.allclose(fused_p, ref_p, 2e-3, 2e-3)