# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing SsdFlatParameter and SsdTensorHandle modules.
"""

import filecmp
import os
import tempfile

import numpy as np
import pytest
import torch

import fairscale.experimental.nn.ssd_offload as so
from fairscale.utils import torch_version

# Note: SSD offload needs the nightly PyTorch build to work, so the check below targets the
# next PyTorch release: every test in this module is skipped when torch is older than 1.11.0.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")


def _init():
    """Seed the torch and numpy random number generators with 0 so tests are reproducible."""
    # torch is seeded first, then numpy, matching the order tests have always relied on.
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(0)


def test_write_read():
    """Round-trip a tensor through so.write / so.read and check it comes back unchanged."""
    _init()

    with tempfile.NamedTemporaryFile() as tmp_file:
        path = tmp_file.name
        expected = torch.rand(128, dtype=torch.float32)
        loaded = torch.zeros_like(expected)
        # Sanity check: the destination buffer must differ before the read,
        # otherwise the final comparison would prove nothing.
        assert not torch.equal(expected, loaded)

        so.write(expected, path)
        so.read(loaded, path)
        assert torch.equal(expected, loaded)


def test_ssd_handle_dispatch_fwd():
    """Check that an in-place op on an SsdTensorHandle matches the same op on a plain tensor.

    The handle is created from a random tensor, written to a temporary file
    (releasing the in-memory tensor), and read back before and after ``add_(1)``.
    """
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((128))
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        # This should trigger the torch_dispatch code and write
        # back the results to the file
        ssd_handle.add_(1)
        plus1_tensor = orig_tensor.add(1)
        assert torch.equal(ssd_handle.to_tensor(), plus1_tensor)


def test_ssd_handle_dispatch_bwd():
    """Check that backward through an SsdTensorHandle yields the same grad as a plain tensor.

    The same ``+ 1`` / ``sum()`` / ``backward()`` sequence is run on the handle and on a
    detached copy of the original tensor, and the resulting ``.grad`` values are compared.
    """
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)
        # Independent leaf tensor with the same values, used as the reference.
        orig_copy = orig_tensor.clone().detach().requires_grad_(True)
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        y1 = ssd_handle + 1
        y2 = orig_copy + 1
        y1.sum().backward()
        y2.sum().backward()

        assert torch.equal(ssd_handle.grad, orig_copy.grad)


def test_ssd_handle_train_simple():
    """Run one SGD step on an SsdTensorHandle and on a plain tensor copy and compare the results.

    After the step, the handle is re-pointed at its backing file so that the
    comparison reads the value stored on disk rather than a cached tensor.
    Skipped on torch >= 1.12.0 ("to be fixed").
    """
    if torch_version() >= (1, 12, 0):
        pytest.skip("to be fixed")

    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)

        # Build an independent reference tensor holding the same values.
        with torch.no_grad():
            orig_copy = torch.empty_like(orig_tensor)
            orig_copy.copy_(orig_tensor)
            orig_copy.requires_grad = True

        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
        optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
        optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)

        y1 = ssd_handle + 1
        optimizer_ssd.zero_grad()
        y1.sum().backward()
        optimizer_ssd.step()

        y2 = orig_copy + 1
        optimizer_orig.zero_grad()
        y2.sum().backward()
        optimizer_orig.step()

        # make sure we are using the file version not the cached tensor
        ssd_handle.point_to_file(f.name, 0)
        assert torch.equal(ssd_handle.to_tensor(), orig_copy)


def test_torch_save_load_ssd_flat_param_on_disk():
    """Save and reload an SsdFlatParameter that has been written to file.

    The parameter is checkpointed with ``so.torch_saver.save`` under a
    ``CheckpointPathContextManager`` override path, reloaded with ``torch.load``,
    and the file found under the override directory is compared byte-for-byte
    with the original tensor file.
    """
    _init()

    # TENSOR_SHAPE = (1024, 1024, 2048)
    # use smaller shape for unit tests
    TENSOR_SHAPE = (1024, 321)

    # Context managers ensure the temporary files and directory are cleaned up
    # deterministically, even when an assertion below fails.
    with tempfile.NamedTemporaryFile(prefix="tensor") as orig_file, tempfile.NamedTemporaryFile(
        prefix="checkpoint", suffix=".pt"
    ) as checkpoint_file, tempfile.TemporaryDirectory(prefix="checkpoint_dir") as checkpoint_load_directory:
        ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for _ in range(4)]
        ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
        ssd_handle.set_file_params(orig_file.name, 0)
        ssd_handle.to_file()
        ref_tensors = []

        # after deleting ref_tensor, memory usage should be very low
        # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
        with so.CheckpointPathContextManager(override_path=checkpoint_load_directory):
            so.torch_saver.save(ssd_handle, checkpoint_file.name)
        # below line saves file to checkpoint_load_directory/orig_file.name
        # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
        # 1000x because that's how many elements the python unpickler
        # will buffer before passing to the SsdTensor
        # The loaded handle is not inspected; it stays bound for the rest of the test.
        test_ssd_handle = torch.load(checkpoint_file)  # noqa: F841
        restored_path = os.path.join(checkpoint_load_directory, os.path.basename(orig_file.name))
        assert filecmp.cmp(orig_file.name, restored_path, shallow=False)


def test_torch_save_load_ssd_flat_param_on_mem():
    """Save and reload an SsdFlatParameter that was never written to file.

    Unlike the on-disk variant, ``to_file()`` is not called before saving. The
    reloaded parameter is compared with the original using ``torch.equal``.
    """
    _init()

    # TENSOR_SHAPE = (1024, 1024, 2048)
    # use smaller shape for unit tests
    TENSOR_SHAPE = (1024, 321)

    # Context managers ensure the temporary files and directory are cleaned up
    # deterministically, even when an assertion below fails.
    with tempfile.NamedTemporaryFile(prefix="tensor") as orig_file, tempfile.NamedTemporaryFile(
        prefix="checkpoint", suffix=".pt"
    ) as checkpoint_file, tempfile.TemporaryDirectory(prefix="checkpoint_dir") as checkpoint_load_directory:
        ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for _ in range(4)]
        ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
        ssd_handle.set_file_params(orig_file.name, 0)
        ref_tensors = []

        # after deleting ref_tensor, memory usage should be very low
        # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
        with so.CheckpointPathContextManager(override_path=checkpoint_load_directory):
            so.torch_saver.save(ssd_handle, checkpoint_file.name)
        # below line saves file to checkpoint_load_directory/orig_file.name
        # NOTE(review): the comment above also appears in the on-disk test; confirm it
        # still applies when the parameter was never written to file.
        # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
        # 1000x because that's how many elements the python unpickler
        # will buffer before passing to the SsdTensor
        test_ssd_handle = torch.load(checkpoint_file)
        assert torch.equal(ssd_handle, test_ssd_handle)


def test_ssd_param_train_simple():
    """Run one SGD step on an SsdParameter and on a regular nn.Parameter and compare the results.

    After the step, the SsdParameter is re-pointed at its backing file so that the
    comparison reads the value stored on disk rather than a cached tensor.
    Skipped on torch >= 1.12.0 ("to be fixed").
    """
    if torch_version() >= (1, 12, 0):
        pytest.skip("to be fixed")

    _init()
    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4))

        # Copy the values into a fresh tensor and wrap it as the reference parameter.
        with torch.no_grad():
            orig_copy = torch.empty_like(orig_tensor)
            orig_copy.copy_(orig_tensor)
            param = torch.nn.Parameter(orig_copy)

        ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
        ssd_param.point_to_tensor(orig_copy)
        ssd_param.set_file_params(f.name, 0)
        ssd_param.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_param.to_tensor(), orig_tensor)
        optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1)
        optimizer_orig = torch.optim.SGD([param], lr=0.1)

        y1 = ssd_param + 1
        optimizer_ssd.zero_grad()
        y1.sum().backward()
        optimizer_ssd.step()

        y2 = param + 1
        optimizer_orig.zero_grad()
        y2.sum().backward()
        optimizer_orig.step()

        # make sure we are using the file version not the cached tensor
        ssd_param.point_to_file(f.name, 0)
        assert torch.equal(ssd_param.to_tensor(), param)


def test_ssd_flat_parameter_basic():
    """Build an SsdFlatParameter from three parameters and check its per-parameter views.

    Each view returned by ``get_param_views()`` must have the same shape and the
    same values as the parameter it was built from. The flat parameter is then
    written to its backing file.
    """
    _init()
    with tempfile.NamedTemporaryFile() as f:
        refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
        ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False)
        ssd_flat_param.set_file_params(f.name, 0)

        param_views = list(ssd_flat_param.get_param_views())

        assert refa_param.shape == param_views[0].shape
        assert refb_param.shape == param_views[1].shape
        assert refc_param.shape == param_views[2].shape

        assert torch.equal(refa_param, param_views[0])
        assert torch.equal(refb_param, param_views[1])
        assert torch.equal(refc_param, param_views[2])
        ssd_flat_param.to_file()