test_adam_optim.py
from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

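# (optimizer class, parameter device) pairs under test; FusedAdam is only exercised on CUDA,
# while CPUAdam and HybridAdam are exercised on both CPU and CUDA parameters.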
_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device("cuda:0")),
    (CPUAdam, torch.device("cpu")),
    (CPUAdam, torch.device("cuda:0")),
    (HybridAdam, torch.device("cpu")),
    (HybridAdam, torch.device("cuda:0")),
]

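# (param dtype, grad dtype) pairs: params are kept in fp32 while grads may be fp16/bf16 under AMP.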
_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
]

N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
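    """Split BERT parameters into a weight-decay group and a no-decay group (biases and LayerNorm weights)."""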
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
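    """Give both models identical random gradients, with the tested model's grads cast to ``g_dtype``."""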
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        torch_p.grad = torch.rand_like(torch_p)
        # PyTorch forbids assigning a grad whose dtype differs from the param's, so
        # temporarily point p.data at the g_dtype tensor, attach it as the grad,
        # then restore the original parameter data.
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p


@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
) -> None:
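    """Compare each ColossalAI Adam variant against torch.optim.Adam/AdamW on BERT over ``N_STEPS`` steps."""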
    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

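    # fp16/bf16 grads accumulate more rounding error than fp32, so loosen the tolerances for them.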
    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # If grads overflow, the weights are not updated, so p should contain no NaNs.
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
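

if __name__ == "__main__":
    # A minimal manual entry point (an illustrative sketch, not part of the pytest run):
    # this picks one valid combination from the parametrize lists above and assumes a
    # CUDA device is available.
    test_adam_optim_on_bert(HybridAdam, torch.device("cuda:0"), True, torch.float, torch.float)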