# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

import os

import pytest
import torch

from fairscale.experimental.nn import MEVO
from fairscale.experimental.nn.mevo import BaselineSoftmaxNllLoss, get_data
from fairscale.utils.testing import skip_if_no_cuda


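# Shared test-data fixture, parameterized over fp16 and fp32: get_data returns
# an (input, weight, target) triple for an (input @ weight) projection.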
@pytest.fixture(scope="session", params=[torch.float16, torch.float32])
def input_data(request):
    shape = ((2, 3), (3, 4))
    return get_data(shape, dtype=request.param)


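# Module-level caches, presumably intended to share baseline (dense) outputs
# and grads across test runs.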
_dense_out = {}  # type: ignore
_dense_grad = {}  # type: ignore


@skip_if_no_cuda
def test_mevo():
    """Test the MEVO kernel by itself."""
    torch.random.manual_seed(os.getpid())
    shape = ((5, 3), (3, 7))
    # Set ``large = True`` to exercise a realistically sized shape when testing locally.
    large = False
    if large:
        shape = ((1 * 2048, 4096), (4096, 256008))
    print("\nshapes are", shape)

    input, weight, target = get_data(shape, dtype=torch.float16)
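    # MEVO fuses the vocab projection with the softmax + NLL loss and runs it
    # in tiles (hence the lower peak memory measured below); tile_factor
    # controls how finely the computation is split.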
    k = MEVO(weight, tile_factor=16)

    o = k(input, target)
    o.backward()
    print(o, o.shape)
    del o

    cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
    mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
    print("cur and peak mem for tiled fwd+bwd =", cur_mem, mem)

    assert input.shape == input.grad.shape
    # Stash the input tensor on the CPU so the GPU copy can be freed below.
    input_cpu = input.data.cpu()
    input_grad1 = input.grad.cpu()
    del input

    cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
    mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
    print("after moving input and its grad, cur and peak mem for tiled fwd+bwd =", cur_mem, mem)

    print(weight.grad.norm(), weight.grad)
    g1 = weight.grad.clone()
    weight.grad = None

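    # Re-run the same input through the unfused baseline (full softmax + NLL)
    # to obtain reference gradients.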
    input = input_cpu.cuda().requires_grad_(True)
    refk = BaselineSoftmaxNllLoss(weight)
    o = refk(input, target)
    o.backward()
    print(o, o.shape)
    del o
    print(weight.grad.norm(), weight.grad)
    g2 = weight.grad.clone()
    input_grad2 = input.grad.cpu()

    # Print the diffs. We move them to GPU with .cuda() because, in PyTorch 1.7
    # and 1.8, min() and max() are not implemented for CPU float16 tensors.
    diff = g1 - g2
    print("weight grad diff", diff.cuda().min(), diff.cuda().max())
    diff = input_grad1 - input_grad2
    print("input grad diff", diff.cuda().min(), diff.cuda().max())