test_cpu_adam.py 2.1 KB
Newer Older
Jeff Rasley's avatar
Jeff Rasley committed
1
2
3
4
5
6
7
import argparse
import torch
import time
import numpy as np
import pytest
import copy

8
import deepspeed
9
10
11
12
13
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.op_builder import CPUAdamBuilder

# Skip every test in this module when DeepSpeed reports the cpu-adam op as
# not compatible. A skip issued at module import time must pass
# allow_module_level=True; without it pytest reports a collection error
# instead of skipping.
if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
    pytest.skip("cpu-adam is not compatible", allow_module_level=True)
14

Jeff Rasley's avatar
Jeff Rasley committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

def check_equal(first, second, atol=1e-2, verbose=False):
    """Assert that two tensors hold element-wise close values.

    Args:
        first: Tensor (or Parameter) with the values under test.
        second: Tensor (or Parameter) with the reference values.
        atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``.
        verbose: When true, print both flattened arrays before comparing.

    Raises:
        AssertionError: If the values differ by more than the tolerance.
    """
    # .cpu() is a no-op for tensors already in host memory, and lets callers
    # pass device tensors directly, since .numpy() only works on CPU tensors.
    x = first.detach().cpu().numpy()
    y = second.detach().cpu().numpy()
    if verbose:
        print("x = {}".format(x.flatten()))
        print("y = {}".format(y.flatten()))
        print('-' * 80)
    np.testing.assert_allclose(x, y, err_msg="param-update dismatch!", atol=atol)

@pytest.mark.parametrize('model_size', [64, 22, 55, 127, 1024, 1048576])
def test_cpu_adam_opt(model_size):
    """Compare DeepSpeedCPUAdam against reference optimizers.

    Identical 1-D parameters of length ``model_size`` are handed to
    DeepSpeedCPUAdam, ``torch.optim.AdamW`` and, when a CUDA device is
    available, ``FusedAdam``. Each optimizer takes ten steps with identical
    random gradients, after which the resulting parameters must agree
    within ``atol=1e-2``.
    """
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    device = 'cpu'
    # FusedAdam needs its parameter on a CUDA device, so only include that
    # comparison where one exists; the CPU comparison always runs.
    use_cuda = torch.cuda.is_available()

    # One random draw shared by all optimizers, so each starts from the same
    # values. Clones keep the parameters independent of one another.
    init = torch.randn(model_size, device=device)
    param = torch.nn.Parameter(init.clone())
    param1 = torch.nn.Parameter(init.clone())

    optimizer = DeepSpeedCPUAdam([param])
    optimizer1 = torch.optim.AdamW([param1])

    param2 = None
    optimizer2 = None
    if use_cuda:
        param2 = torch.nn.Parameter(init.clone().cuda())
        optimizer2 = FusedAdam([param2])

    for _ in range(10):
        # Same gradient for every optimizer; separate tensors so that no
        # optimizer can observe another's in-place changes.
        grad = torch.randn(model_size, device=device)
        param.grad = grad.clone()
        param1.grad = grad.clone()

        optimizer.step()
        if use_cuda:
            param2.grad = grad.clone().cuda()
            optimizer2.step()
        optimizer1.step()

    check_equal(param, param1, atol=1e-2, verbose=True)
    if use_cuda:
        check_equal(param, param2.cpu(), atol=1e-2, verbose=True)