# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

import os

import pytest
import torch

from fair_dev.testing.testing import skip_if_no_cuda
from fairscale.experimental.nn import MEVO
from fairscale.experimental.nn.mevo import BaselineSoftmaxNllLoss, get_data
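# MEVO computes the (output projection -> softmax -> NLL loss) chain in tiles so the
# full vocab-sized logits do not have to be materialized at once; BaselineSoftmaxNllLoss
# is the untiled reference it is compared against in test_mevo below.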


@pytest.fixture(scope="session", params=[torch.float16, torch.float32])
def input_data(request):
    shape = ((2, 3), (3, 4))
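    # get_data builds the (input, weight, target) tensors for the given
    # (input_shape, weight_shape) pair; the dtype is parametrized over fp16/fp32 above.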
    return get_data(shape, dtype=request.param)


_dense_out = {}  # type: ignore
_dense_grad = {}  # type: ignore
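# The two dicts above are module-level caches; presumably intended for sharing dense
# baseline outputs/grads across runs, though the tests below do not read them.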


@skip_if_no_cuda
def test_mevo_eval():
    """Test eval mode without target tensor"""
    weight = torch.nn.Linear(3, 4).cuda().weight
    input = torch.rand(1, 5, 3).cuda()
    k = MEVO(weight)
    k.eval()
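    # With no target in eval mode, MEVO returns the projected output of shape
    # (batch, seq, vocab) rather than a loss, which is what the shape assert checks.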
    out = k(input, None)
    assert out.shape == (1, 5, 4)


# Note on lmcl_scale: an overly large value (e.g. 64) with small input shapes will
# produce inf/nan in MEVO; larger scale values are only needed for large input shapes.
@skip_if_no_cuda
@pytest.mark.parametrize("lmcl_scale", [None, 8])
def test_mevo(lmcl_scale):
    """Test the MEVO kernel in a single process (no DDP/FSDP)."""
    # Set seed and reset peak mem so that peak measure below is correct.
    torch.random.manual_seed(os.getpid())
    torch.cuda.reset_peak_memory_stats()
    shape = ((5, 3), (3, 7))
    # Turn on large data for local testing.
    large = False
    if large:
        shape = ((1 * 2048, 4096), (4096, 256008))
    print("\nshapes are", shape)

    input, weight, target = get_data(shape, dtype=torch.float16)
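    # tile_factor controls how finely MEVO tiles the fused computation (presumably
    # trading extra kernel launches for a lower peak memory footprint).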
    k = MEVO(weight, tile_factor=16, scale=lmcl_scale)

    o = k(input, target)
    o.backward()
    print("MEVO loss", o, o.shape)
    del o

    cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
    mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
    print("cur and peak mem for tiled fwd+bwd =", cur_mem, mem)

    assert input.shape == input.grad.shape
    input_data = input.data.cpu()
    input_grad1 = input.grad.cpu()
    del input

    cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
    mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
    print("after moving input and its grad, cur and peak mem for tiled fwd+bwd =", cur_mem, mem)

    print("MEVO grad norm and grad", weight.grad.norm(), weight.grad)
    g1 = weight.grad.clone()
    weight.grad = None

    input = input_data.cuda().requires_grad_(True)
    refk = BaselineSoftmaxNllLoss(weight, scale=lmcl_scale)
    o = refk(input, target)
    o.backward()
    print("Reference loss", o, o.shape)
    del o
    print("Reference grad norm and grad", weight.grad.norm(), weight.grad)
    g2 = weight.grad.clone()
    input_grad2 = input.grad.cpu()

    # Print the diff. We use .cuda() since in torch 1.7 and 1.8, min() and max() are not
    # implemented for cpu float16. The diff should in general be below 0.01 in magnitude.
    diff = g1 - g2
    print("weight grad diff", diff.cuda().min(), diff.cuda().max())
    diff = input_grad1 - input_grad2
    print("input grad diff", diff.cuda().min(), diff.cuda().max())