# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
7
Testing SsdFlatParameter and SsdTensorHandle modules.
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""

import tempfile

import numpy as np
import pytest
import torch

import fairscale.experimental.nn.ssd_offload as so
from fairscale.utils import torch_version

# Note: SSD offload requires the PyTorch nightly build, hence the check against the next PyTorch release.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
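
# Usage pattern exercised by the tests below (a rough sketch drawn from the
# tests themselves, not an authoritative reference for the
# fairscale.experimental.nn.ssd_offload API):
#
#   with tempfile.NamedTemporaryFile() as f:
#       handle = so.SsdTensorHandle.from_tensor(torch.randn((128)))
#       handle.set_file_params(f.name, 0)
#       handle.to_file(release_tensor_after_write=True)  # tensor contents now live on disk
#       restored = handle.to_tensor()                     # read the data back into memory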


def _init():
    torch.manual_seed(0)
    np.random.seed(0)


def test_write_read():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        ref_tensor = torch.rand((128), dtype=torch.float32)
        test_tensor = torch.zeros_like(ref_tensor)
        assert not torch.equal(ref_tensor, test_tensor)
        so.write(ref_tensor, f.name)
        so.read(test_tensor, f.name)
        assert torch.equal(ref_tensor, test_tensor)


def test_ssd_handle_dispatch_fwd():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((128))
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        # This should trigger the torch_dispatch code and write
        # back the results to the file
        ssd_handle.add_(1)
        plus1_tensor = orig_tensor.add(1)
        assert torch.equal(ssd_handle.to_tensor(), plus1_tensor)


def test_ssd_handle_dispatch_bwd():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)
        orig_copy = orig_tensor.clone().detach().requires_grad_(True)
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        y1 = ssd_handle + 1
        y2 = orig_copy + 1
        y1.sum().backward()
        y2.sum().backward()

        assert torch.equal(ssd_handle.grad, orig_copy.grad)


def test_ssd_handle_train_simple():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)

        with torch.no_grad():
            orig_copy = torch.empty_like(orig_tensor)
            orig_copy.copy_(orig_tensor)
            orig_copy.requires_grad = True

        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
        optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
        optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)

        y1 = ssd_handle + 1
        optimizer_ssd.zero_grad()
        y1.sum().backward()
        optimizer_ssd.step()

        y2 = orig_copy + 1
        optimizer_orig.zero_grad()
        y2.sum().backward()
        optimizer_orig.step()

        # make sure we are using the file version, not the cached tensor
        ssd_handle.point_to_file(f.name, 0)
        assert torch.equal(ssd_handle.to_tensor(), orig_copy)


def test_ssd_flat_param_train_simple():
    _init()
    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4))

        with torch.no_grad():
            orig_copy = torch.empty_like(orig_tensor)
            orig_copy.copy_(orig_tensor)
        param = torch.nn.Parameter(orig_copy)

        ssd_flat_param = so.SsdFlatParameter([param], f.name, True)

        assert torch.equal(list(ssd_flat_param.get_param_views())[0], orig_tensor)
        optimizer_ssd = torch.optim.SGD([ssd_flat_param], lr=0.1)
        optimizer_orig = torch.optim.SGD([param], lr=0.1)

        y1 = ssd_flat_param + 1
        optimizer_ssd.zero_grad()
        y1.sum().backward()
        optimizer_ssd.step()

        y2 = param + 1
        optimizer_orig.zero_grad()
        y2.sum().backward()
        optimizer_orig.step()

        # make sure we are using the file version, not the cached tensor
        ssd_flat_param.point_to_file(f.name, 0)
        assert torch.equal(list(ssd_flat_param.get_param_views())[0], param)


def test_ssd_flat_parameter_basic():
    _init()
    with tempfile.NamedTemporaryFile() as f:
        refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
        ssd_flat_param = so.SsdFlatParameter([refa_param, refb_param, refc_param], f.name, False)

        param_views = list(ssd_flat_param.get_param_views())

        assert refa_param.shape == param_views[0].shape
        assert refb_param.shape == param_views[1].shape
        assert refc_param.shape == param_views[2].shape

        assert torch.equal(refa_param, param_views[0])
        assert torch.equal(refb_param, param_views[1])
        assert torch.equal(refc_param, param_views[2])
        ssd_flat_param.to_file()