import pytest
import torch
import deepspeed
from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3


def _skip_autocast_test():
    # LinearModuleForZeroStage3 relies on torch.cuda.amp.custom_fwd/custom_bwd
    # for its autocast support; skip the autocast tests on older torch
    # versions that do not provide them.
    try:
        from torch.cuda.amp import custom_fwd, custom_bwd  # noqa: F401
    except (ImportError, AttributeError):
        return True

    return False


@pytest.mark.parametrize('half_op', [False, True])
def test_missing_amp_autocast(tmpdir, half_op):
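    """Without any autocast context, LinearModuleForZeroStage3 should run in
    the module's own precision: the output dtype matches the weight dtype for
    both fp32 and fp16 parameters."""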
    hidden_dim = 4
    if half_op:
        input = torch.randn(hidden_dim).cuda().half()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half()
    else:
        input = torch.randn(hidden_dim).cuda()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    output = ds_linear(input)
    assert output.dtype == ds_linear.weight.dtype


@pytest.mark.parametrize('half_op', [False, True])
def test_disable_autocast_linear(tmpdir, half_op):
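    """With autocast explicitly disabled, the module should behave exactly as
    in the no-autocast case: the output dtype matches the weight dtype."""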
    if _skip_autocast_test():
        pytest.skip("amp autocast is not available")

    hidden_dim = 4
    if half_op:
        input = torch.randn(hidden_dim).cuda().half()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half()
    else:
        input = torch.randn(hidden_dim).cuda()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    with torch.cuda.amp.autocast(enabled=False):
        output = ds_linear(input)
        assert output.dtype == ds_linear.weight.dtype


@pytest.mark.parametrize('half_input, half_weight',
                         [(False, False),
                          (False, True),
                          (True, False),
                          (True, True)])
def test_autocast_linear(tmpdir, half_input, half_weight):
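    """Inside an enabled autocast region, the linear op should execute in half
    precision for every combination of input/weight dtypes, so the output is
    always torch.half."""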
    if _skip_autocast_test():
        pytest.skip("amp autocast is not available")

    hidden_dim = 4
    input = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    if half_input:
        input = input.half()

    if half_weight:
        ds_linear = ds_linear.half()

    with torch.cuda.amp.autocast():
        output = ds_linear(input)
        assert output.dtype == torch.half