test_containers.py 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test utility classes from containers.py. """

import random

import pytest
import torch

from fairscale.utils.containers import (
    apply_to_tensors,
    pack_kwargs,
    split_non_tensors,
    unpack_kwargs,
    unpack_non_tensors,
)


@pytest.mark.parametrize("devices", [["cpu"], ["cuda"], ["cpu", "cuda"]])
def test_apply_to_tensors(devices):
    """Test apply_to_tensors for both cpu & gpu.

    Builds a nested mix of lists, dicts, sets and tuples that contains random
    tensors next to non-tensor values, then checks that the callback passed to
    ``apply_to_tensors`` sees exactly as many tensor elements as were created.

    Args:
        devices: device names each random tensor may be placed on.
    """
    # Only cases that request "cuda" need a GPU. Without the parentheses,
    # `and` binds tighter than `or`, so a CPU-only machine would also skip
    # the pure-"cpu" case.
    if "cuda" in devices and (not torch.cuda.is_available() or torch.cuda.device_count() < 1):
        pytest.skip("Skipped due to lack of GPU")
    expected = 0

    def get_a_tensor():
        """Return a random tensor on random device."""
        dev = random.choice(devices)
        shape = random.choice(((1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
        t = torch.rand(shape).to(dev)
        # Track how many elements were generated in total.
        nonlocal expected
        expected += t.numel()
        return t

    # create a mixed bag of data.
    data = [1, "str"]
    data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
    data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
    data.append(([1], get_a_tensor(), 1, [get_a_tensor()], {1, 2}))

    total = 0

    def fn(t):
        """Accumulate the element count of each visited tensor; return it unchanged."""
        nonlocal total
        total += t.numel()
        return t

    apply_to_tensors(fn, data)
    assert total == expected, f"{total} vs. {expected}"


def test_pack_unpack():
    """Exercise pack_kwargs and unpack_kwargs on positional, keyword and mixed input."""
    # Positional arguments only: no keys are recorded.
    keys, flat = pack_kwargs(1, 2, 3, 4)
    assert keys == ()
    assert flat == (1, 2, 3, 4)

    # Keyword arguments only: keys and values come back in call order.
    keys, flat = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5,))
    assert keys == ("a", "b", "c", "d", "e")
    assert flat == (1, {2: "2"}, {3}, [4], (5,))

    # Mixed: positional values first, then keyword values.
    keys, flat = pack_kwargs(1, 2, a=3, b=4)
    assert keys == ("a", "b")
    assert flat == (1, 2, 3, 4)

    # Unpack the same flat tuple with different key lists:
    # (keys to use, expected positional args, expected keyword args).
    unpack_cases = (
        (keys, (1, 2), {"a": 3, "b": 4}),
        ([], (1, 2, 3, 4), {}),
        (["a", "b", "c", "d"], (), {"a": 1, "b": 2, "c": 3, "d": 4}),
    )
    for case_keys, want_args, want_kwargs in unpack_cases:
        got_args, got_kwargs = unpack_kwargs(case_keys, flat)
        assert got_args == want_args
        assert got_kwargs == want_kwargs

    # too many keys should assert.
    with pytest.raises(AssertionError):
        unpack_kwargs(["a", "b", "c", "d", "e"], flat)


def test_split_unpack():
    """Exercise split_non_tensors and unpack_non_tensors, including round trips."""
    x = torch.Tensor([1])
    y = torch.Tensor([2])

    # Tensors first, then non-tensors: also check the packed layout.
    mixed = (x, y, None, 3)
    only_tensors, packed = split_non_tensors(mixed)
    assert only_tensors == (x, y)
    expected_packed = {
        "is_tensor": [True, True, False, False],
        "objects": [None, 3],
    }
    assert packed == expected_packed
    assert unpack_non_tensors(only_tensors, packed) == (x, y, None, 3)

    # Round trips for non-tensors first, and for non-tensors only.
    for original in ((None, 3, x, y), (None, 3)):
        only_tensors, packed = split_non_tensors(original)
        rebuilt = unpack_non_tensors(only_tensors, packed)
        assert rebuilt == original

    # Tensors only; `only_tensors` from this split is reused below.
    only_tensors, packed = split_non_tensors((x, y))
    rebuilt = unpack_non_tensors(only_tensors, packed)
    assert rebuilt == (x, y)

    # None as the packed info returns the tensors unchanged.
    rebuilt = unpack_non_tensors(only_tensors, None)
    assert rebuilt == (x, y)

    # assert the second arg should be a dict.
    with pytest.raises(AssertionError):
        unpack_non_tensors(only_tensors, set())

    # assert the content of the second arg should be sane.
    with pytest.raises(AssertionError):
        unpack_non_tensors(only_tensors, {"is_tensor": [], "objects": []})