Commit 0cd65242 authored by Mandeep Singh Baines

Initial commit
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
T_co = TypeVar('T_co', covariant=True)
class Sampler(Generic[T_co]):
def __init__(self, data_source: Sized) -> None: ...
def __iter__(self) -> Iterator[T_co]: ...
def __len__(self) -> int: ...
class SequentialSampler(Sampler[int]):
pass
class RandomSampler(Sampler[int]):
num_samples: int
def __init__(self, data_source: Sized, replacement: bool = ..., num_samples: Optional[int] = ...) -> None: ...
class SubsetRandomSampler(Sampler[int]):
def __init__(self, indices: Sequence[int]) -> None: ...
class WeightedRandomSampler(Sampler[int]):
def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool = ...) -> None: ...
class BatchSampler(Sampler[List[int]]):
def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ...
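# For orientation (editorial sketch, not part of the stub file): the concrete
# torch.utils.data classes behind these stubs compose like this at runtime:
#
#     from torch.utils.data import BatchSampler, SequentialSampler
#
#     dataset = list(range(10))  # any Sized object works as a data_source
#     assert list(SequentialSampler(dataset)) == list(range(10))
#     batches = list(BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False))
#     assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]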
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#MODIFIED BY TORCHGPIPE
debug: bool = ...
cuda: str = ...
git_version: str = ...
#END
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tests/__init__.py lets pytest import the application without a custom sys.path or PYTHONPATH.
# See also: https://docs.pytest.org/en/latest/goodpractices.html
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
@pytest.fixture(autouse=True)
def manual_seed_zero():
torch.manual_seed(0)
@pytest.fixture(scope="session")
def cuda_sleep():
# Warm-up CUDA.
torch.empty(1, device="cuda")
# From test/test_cuda.py in PyTorch.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
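# _sleep(1000000) spins the GPU for one million cycles and elapsed_time()
# reports milliseconds, so the ratio below estimates GPU cycles per millisecond.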
cycles_per_ms = 1000000 / start.elapsed_time(end)
def cuda_sleep(seconds):
torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
return cuda_sleep
def pytest_report_header():
return f"torch: {torch.__version__}"
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from torch import nn
from fairscale.nn.pipe.skip import Namespace, skippable, stash
def test_namespace_difference():
ns1 = Namespace()
ns2 = Namespace()
assert ns1 != ns2
def test_namespace_copy():
ns = Namespace()
assert copy.copy(ns) == ns
assert copy.copy(ns) is not ns
def test_skippable_repr():
@skippable(stash=["hello"])
class Hello(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
yield stash("hello", x)
return self.conv(x)
m = Hello()
assert (
repr(m)
== """
@skippable(Hello(
(conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
))
""".strip()
)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_1to3(balance, checkpoint):
if torch.cuda.device_count() < len(balance):
pytest.skip("at least %d cuda devices required" % len(balance))
@skippable(stash=["1to3"])
class Layer1(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
yield stash("1to3", input)
output = self.conv(input)
return output
class Layer2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
output = self.conv(input)
return output
@skippable(pop=["1to3"])
class Layer3(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
skip_1to3 = yield pop("1to3")
output = self.conv(input) + skip_1to3
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe(model, balance, chunks=3, checkpoint=checkpoint)
in_device = model.devices[0]
out_device = model.devices[-1]
input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
output = model(input)
loss = output.mean()
loss.backward()
assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=2e-1)
assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))
def test_none_skip():
@skippable(stash=["none"])
class Stash(nn.Module):
def forward(self, input):
yield stash("none", None)
return input
@skippable(pop=["none"])
class Pop(nn.Module):
def forward(self, input):
none = yield pop("none")
assert none is None
return input
model = nn.Sequential(Stash(), Pop())
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5)
input = torch.rand(10, requires_grad=True)
output = model(input)
def assert_grad_fn_is_not_portal(grad_fn, visited=set()):
if grad_fn in visited or grad_fn is None:
return
assert not isinstance(grad_fn, PortalBlue._backward_cls)
assert not isinstance(grad_fn, PortalCopy._backward_cls)
assert not isinstance(grad_fn, PortalOrange._backward_cls)
visited.add(grad_fn)
for next_grad_fn, _ in grad_fn.next_functions:
assert_grad_fn_is_not_portal(next_grad_fn, visited)
assert_grad_fn_is_not_portal(output.grad_fn)
output.sum().backward()
assert input.grad.mean().item() == 1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn
from fairscale.nn.pipe.skip import Namespace, pop, skippable, stash
from fairscale.nn.pipe.skip.layout import inspect_skip_layout
class Pass(nn.Module):
def forward(self, input):
return input
@skippable(stash=["foo"])
class StashFoo(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input
@skippable(pop=["foo"])
class PopFoo(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return input + foo
@skippable(stash=["bar"])
class StashBar(nn.Module):
def forward(self, input):
yield stash("bar", input)
return input
@skippable(pop=["bar"])
class PopBar(nn.Module):
def forward(self, input):
bar = yield pop("bar")
return input + bar
def test_no_skippables():
p1 = nn.Sequential(Pass())
p2 = nn.Sequential(Pass())
layout = inspect_skip_layout([p1, p2])
policy = [list(layout.copy_policy(i)) for i in range(2)]
assert policy == [[], []]
def test_inner_partition():
p1 = nn.Sequential(StashFoo(), PopFoo())
p2 = nn.Sequential(Pass())
layout = inspect_skip_layout([p1, p2])
policy = [list(layout.copy_policy(i)) for i in range(2)]
assert policy == [[], []]
def test_adjoining_partitions():
p1 = nn.Sequential(StashFoo())
p2 = nn.Sequential(PopFoo())
layout = inspect_skip_layout([p1, p2])
policy = [list(layout.copy_policy(i)) for i in range(2)]
assert policy == [[], [(0, None, "foo")]]
def test_far_partitions():
p1 = nn.Sequential(StashFoo())
p2 = nn.Sequential(Pass())
p3 = nn.Sequential(PopFoo())
layout = inspect_skip_layout([p1, p2, p3])
policy = [list(layout.copy_policy(i)) for i in range(3)]
assert policy == [[], [], [(0, None, "foo")]]
def test_pop_2_from_different_partitions():
p1 = nn.Sequential(StashFoo())
p2 = nn.Sequential(StashBar())
p3 = nn.Sequential(PopBar(), PopFoo())
layout = inspect_skip_layout([p1, p2, p3])
policy = [list(layout.copy_policy(i)) for i in range(3)]
# p3 pops 'bar' before 'foo', but the plan is sorted by source partition index.
assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]]
def test_namespace():
ns1 = Namespace()
ns2 = Namespace()
p1 = nn.Sequential(StashFoo().isolate(ns1))
p2 = nn.Sequential(StashFoo().isolate(ns2))
p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))
layout = inspect_skip_layout([p1, p2, p3])
policy = [list(layout.copy_policy(i)) for i in range(3)]
# p3 pops ns2's 'foo' before ns1's 'foo', but the plan is sorted by source partition index.
assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
@skippable(stash=["skip"])
class Stash(nn.Module):
def forward(self, input):
yield stash("skip", input)
return input
@skippable(pop=["skip"])
class Pop(nn.Module):
def forward(self, input):
skip = yield pop("skip")
return input + skip
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
def test_delete_portal_tensor(train, checkpoint):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
# +----------+ +------------+
#
# With checkpointing:
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
def portal_tensor_life_is(tensor_life, skip_tracker=None):
if skip_tracker is None:
skip_tracker = current_skip_tracker()
# Get the current portal.
portal = list(skip_tracker.portals.values())[0]
if tensor_life == 0:
return portal.tensor_life == 0 and portal.tensor is None
else:
return portal.tensor_life == tensor_life and portal.tensor is not None
# Check the portal tensor after 'Stash'.
stash_ = Stash()
@stash_.register_forward_hook
def check_portal_tensor_after_stash(*_):
if is_checkpointing():
assert portal_tensor_life_is(2)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(1)
pop_ = Pop()
@pop_.register_forward_hook
def check_portal_tensor_after_pop(*_):
if is_checkpointing():
assert portal_tensor_life_is(1)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(0)
class NoPortalTensorAtBackward(nn.Module):
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.skip_tracker = current_skip_tracker()
return input.detach()
@staticmethod
def backward(ctx, grad):
assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
return grad
def forward(self, input):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint)
input = torch.rand(10, requires_grad=True)
if train:
model.train()
output = model(input)
output.norm().backward()
else:
model.eval()
with torch.no_grad():
model(input)
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
def test_no_portal_without_pipe(train, monkeypatch):
def deny(*args, **kwargs):
raise AssertionError("tried to create Portal without Pipe")
monkeypatch.setattr("fairscale.nn.pipe.skip.portal.Portal.__init__", deny)
model = nn.Sequential(Stash(), Pop())
input = torch.rand(10, requires_grad=True)
if train:
model.train()
output = model(input)
output.norm().backward()
else:
model.eval()
with torch.no_grad():
model(input)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from fairscale.nn.pipe.dependency import fork, join
from fairscale.nn.pipe.skip.portal import Portal
from fairscale.nn.pipe.stream import default_stream
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_copy_returns_on_next_device():
portal = Portal(torch.rand(1), tensor_life=1)
prev_stream = default_stream(torch.device("cpu"))
next_stream = default_stream(torch.device("cuda"))
phony = torch.zeros(0, requires_grad=True)
assert phony.device.type == "cpu"
phony = portal.copy(prev_stream, next_stream, phony)
assert phony.device.type == "cuda"
def test_blue_orange():
tensor1 = torch.rand(1, requires_grad=True)
tensor2 = torch.rand(1, requires_grad=True)
# Same as: output = tensor1*2 + tensor2
#
# +----------------------+
# | |
# tensor2 -- PortalBlue -+ +- PortalOrange -+
# | | |
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
#
main = tensor1
portal = Portal(tensor2, tensor_life=2)
phony = portal.blue()
main = join(main, phony)
main, phony = fork(main)
sub = portal.orange(phony)
output = main * 2 + sub
output.backward()
assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
assert torch.allclose(tensor2.grad, torch.tensor([1.0]))
def test_blue_orange_not_requires_grad():
tensor1 = torch.rand(1, requires_grad=True)
tensor2 = torch.rand(1)
# Same as: output = tensor1*2 + tensor2
#
# +----------------------+
# | |
# tensor2 -- PortalBlue -+ +- PortalOrange -+
# | | |
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
#
main = tensor1
portal = Portal(tensor2, tensor_life=2)
phony = portal.blue()
main = join(main, phony)
main, phony = fork(main)
sub = portal.orange(phony)
output = main * 2 + sub
output.backward()
assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
assert tensor2.grad is None
def test_use_grad():
tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life=1)
portal.put_grad(tensor)
assert portal.use_grad() is tensor
# Gradient in a portal is ephemeral.
with pytest.raises(RuntimeError):
portal.use_grad()
class TestTensorLife:
@pytest.fixture
def new_portal(self):
portal = None
def new_portal(tensor_life):
nonlocal portal
tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life)
return portal, tensor
yield new_portal
# A test using this fixture must exhaust the tensor in the portal.
with pytest.raises(RuntimeError):
portal.check_tensor_life()
assert portal.tensor is None
def test_tensor_life_0(self, new_portal):
portal, tensor = new_portal(0)
assert portal.tensor is None
def test_tensor_life_1(self, new_portal):
portal, tensor = new_portal(1)
assert portal.tensor is tensor
portal.blue()
def test_tensor_life_2(self, new_portal):
portal, tensor = new_portal(2)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
def test_tensor_life_3(self, new_portal):
portal, tensor = new_portal(3)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
def test_tensor_life_4(self, new_portal):
portal, tensor = new_portal(4)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
portal.blue()
def test_tensor_life_3_plus_1(self, new_portal):
portal, tensor = new_portal(3)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
another_tensor = torch.rand(1, requires_grad=True)
portal.put_tensor(another_tensor, tensor_life=1)
portal.blue()
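# Reading the matrix above (editorial summary of what the tests assert): every
# blue() or orange() call consumes one unit of tensor_life, the portal drops
# its tensor when the count reaches zero, and put_tensor() recharges the
# portal with a fresh tensor and a fresh life count.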
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import nn
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import SkipTracker, use_skip_tracker
@pytest.fixture(autouse=True)
def skip_tracker():
skip_tracker = SkipTracker()
with use_skip_tracker(skip_tracker):
yield skip_tracker
def test_stash(skip_tracker):
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2
l1 = Stash()
assert len(skip_tracker.tensors) == 0
with use_skip_tracker(skip_tracker):
l1(torch.tensor(42))
assert len(skip_tracker.tensors) == 1
def test_pop():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return foo
l1 = Stash()
l2 = Pop()
output = l2(l1(torch.tensor(42)))
assert output.item() == 42
def test_declare_but_not_use():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
return input * 2
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
return input * 3
l1 = Stash()
l2 = Pop()
with pytest.raises(RuntimeError):
l1(torch.tensor(42))
with pytest.raises(RuntimeError):
l2(torch.tensor(42))
def test_stash_not_declared():
@skippable()
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2
l1 = Stash()
with pytest.raises(RuntimeError):
l1(torch.tensor(42))
def test_pop_not_declared():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2
@skippable()
class Pop(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return foo
l1 = Stash()
l2 = Pop()
latent = l1(torch.tensor(42))
with pytest.raises(RuntimeError):
l2(latent)
def test_pop_not_stashed():
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
yield pop("foo")
l1 = Pop()
with pytest.raises(RuntimeError):
l1(torch.tensor(42))
def test_stash_none():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", None)
return input * 2
l1 = Stash()
l1(torch.tensor(42))
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from queue import Queue
import threading
import pytest
import torch
from torch import nn
from fairscale.nn.pipe.checkpoint import enable_checkpointing, enable_recomputing
from fairscale.nn.pipe.microbatch import Batch
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.layout import SkipLayout
from fairscale.nn.pipe.skip.tracker import SkipTracker, SkipTrackerThroughPotals, current_skip_tracker
def test_default_skip_tracker():
q = Queue()
def f():
q.put(current_skip_tracker())
t = threading.Thread(target=f)
t.start()
t.join()
skip_tracker = q.get()
assert type(skip_tracker) is SkipTracker
assert type(skip_tracker) is not SkipTrackerThroughPotals
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_default_skip_tracker_by_data_parallel():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return foo
model = nn.Sequential(Stash(), Pop())
model = nn.DataParallel(model, device_ids=[0, 0], output_device=0)
input = torch.rand(10, device=0)
output = model(input)
assert torch.allclose(output, input)
def test_reuse_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
a = torch.tensor([2.0])
b = torch.tensor([2.0])
skip_tracker.save(batch, None, "test", a)
portal = skip_tracker.portals[(None, "test")]
skip_tracker.save(batch, None, "test", b)
assert portal is skip_tracker.portals[(None, "test")]
def test_no_copy_no_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
a = torch.tensor([2.0])
b = torch.tensor([2.0])
skip_tracker.save(batch, None, "copy", a)
skip_tracker.save(batch, None, "not_copy", b)
assert (None, "copy") in skip_tracker.portals
assert (None, "copy") not in skip_tracker.tensors
assert (None, "not_copy") in skip_tracker.tensors
assert (None, "not_copy") not in skip_tracker.portals
def test_tensor_life_without_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
tensor = torch.tensor([2.0])
skip_tracker.save(batch, None, "test", tensor)
assert skip_tracker.portals[(None, "test")].tensor_life == 1
skip_tracker.load(batch, None, "test")
assert skip_tracker.portals[(None, "test")].tensor_life == 0
def test_tensor_life_with_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
tensor = torch.tensor([2.0])
with enable_checkpointing():
skip_tracker.save(batch, None, "test", tensor)
assert skip_tracker.portals[(None, "test")].tensor_life == 2
with enable_checkpointing():
skip_tracker.load(batch, None, "test")
assert skip_tracker.portals[(None, "test")].tensor_life == 1
with enable_recomputing():
skip_tracker.load(batch, None, "test")
assert skip_tracker.portals[(None, "test")].tensor_life == 0
with enable_recomputing():
skip_tracker.save(batch, None, "test", tensor)
assert skip_tracker.portals[(None, "test")].tensor_life == 0
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch import nn
from fairscale.nn.pipe.skip import Namespace, skippable, verify_skippables
def test_matching():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
verify_skippables(nn.Sequential(Layer1(), Layer2()))
def test_stash_not_pop():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1()))
assert "no module declared 'foo' as poppable but stashed" in str(e.value)
def test_pop_unknown():
@skippable(pop=["foo"])
class Layer1(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1()))
assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value)
def test_stash_again():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(stash=["foo"])
class Layer2(nn.Module):
pass
@skippable(pop=["foo"])
class Layer3(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
assert "'1' redeclared 'foo' as stashable" in str(e.value)
def test_pop_again():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
@skippable(pop=["foo"])
class Layer3(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
assert "'2' redeclared 'foo' as poppable" in str(e.value)
def test_stash_pop_together_different_names():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"], stash=["bar"])
class Layer2(nn.Module):
pass
@skippable(pop=["bar"])
class Layer3(nn.Module):
pass
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
def test_stash_pop_together_same_name():
@skippable(stash=["foo"], pop=["foo"])
class Layer1(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1()))
assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value)
def test_double_stash_pop():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
@skippable(stash=["foo"])
class Layer3(nn.Module):
pass
@skippable(pop=["foo"])
class Layer4(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4()))
assert "'2' redeclared 'foo' as stashable" in str(e.value)
assert "'3' redeclared 'foo' as poppable" in str(e.value)
def test_double_stash_pop_but_isolated():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
@skippable(stash=["foo"])
class Layer3(nn.Module):
pass
@skippable(pop=["foo"])
class Layer4(nn.Module):
pass
ns1 = Namespace()
ns2 = Namespace()
verify_skippables(
nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2))
)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import pytest
import torch
from torch import nn
from fairscale.nn.pipe.balance import balance_by_size, balance_by_time, blockpartition
from fairscale.nn.pipe.balance.profile import layerwise_sandbox
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
def test_blockpartition():
assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]]
def test_blockpartition_zeros():
assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]]
def test_blockpartition_non_positive_partitions():
with pytest.raises(ValueError):
blockpartition.solve([42], partitions=0)
with pytest.raises(ValueError):
blockpartition.solve([42], partitions=-1)
def test_blockpartition_short_sequence():
with pytest.raises(ValueError):
blockpartition.solve([], partitions=1)
with pytest.raises(ValueError):
blockpartition.solve([42], partitions=2)
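# Editorial sketch (hypothetical test, grounded in the asserts above): solve()
# returns contiguous blocks whose sums are as balanced as contiguity allows.
def test_blockpartition_balanced_sums_sketch():
    parts = blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2)
    assert [sum(block) for block in parts] == [10, 11]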
@pytest.mark.parametrize("device", devices)
def test_balance_by_time(device):
class Delay(nn.Module):
def __init__(self, seconds):
super().__init__()
self.seconds = seconds
def forward(self, x):
time.sleep(self.seconds)
return x
model = nn.Sequential(*[Delay(i / 100) for i in [1, 2, 3, 4, 5, 6]])
sample = torch.rand(1)
balance = balance_by_time(2, model, sample, device=device)
assert balance == [4, 2]
def test_balance_by_time_loop_resets_input():
# nn.Flatten was introduced in PyTorch 1.2.0; define a local equivalent for compatibility.
class Flatten(nn.Module):
def forward(self, x):
return x.flatten(1)
model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10))
sample = torch.rand(10, 3, 8, 8)
balance = balance_by_time(2, model, sample, device="cpu")
assert balance == [1, 2]
@skip_if_no_cuda
def test_balance_by_size_latent():
class Expand(nn.Module):
def __init__(self, times):
super().__init__()
self.times = times
def forward(self, x):
for i in range(self.times):
x = x + torch.rand_like(x, requires_grad=True)
return x
sample = torch.rand(10, 100, 100)
model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
balance = balance_by_size(2, model, sample)
assert balance == [4, 2]
model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]])
balance = balance_by_size(2, model, sample)
assert balance == [2, 4]
@skip_if_no_cuda
def test_balance_by_size_param():
model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)])
sample = torch.rand(7, 1)
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [4, 2]
model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))])
sample = torch.rand(1, 7)
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [2, 4]
@skip_if_no_cuda
def test_balance_by_size_param_scale():
class Tradeoff(nn.Module):
def __init__(self, param_size, latent_size):
super().__init__()
self.fc = nn.Linear(param_size, param_size)
self.latent_size = latent_size
def forward(self, x):
for i in range(self.latent_size):
x = x + torch.rand_like(x, requires_grad=True)
return x
model = nn.Sequential(
Tradeoff(param_size=1, latent_size=6),
Tradeoff(param_size=2, latent_size=5),
Tradeoff(param_size=3, latent_size=4),
Tradeoff(param_size=4, latent_size=3),
Tradeoff(param_size=5, latent_size=2),
Tradeoff(param_size=6, latent_size=1),
)
sample = torch.rand(1, requires_grad=True)
balance = balance_by_size(2, model, sample, param_scale=0)
assert balance == [2, 4]
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [4, 2]
@pytest.mark.parametrize("device", devices)
def test_layerwise_sandbox(device):
model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
model.eval()
for layer in layerwise_sandbox(model, torch.device(device)):
assert layer.training
assert all(p.device.type == device for p in layer.parameters())
assert all(not l.training for l in model)
assert all(p.device.type == "cpu" for p in model.parameters())
@pytest.mark.parametrize("device", devices)
def test_sandbox_during_profiling(device):
model = nn.Sequential(nn.BatchNorm2d(3))
before = {k: v.clone() for k, v in model.state_dict().items()}
sample = torch.rand(1, 3, 10, 10)
balance_by_time(1, model, sample, device=device)
after = model.state_dict()
assert before.keys() == after.keys()
for key, value in before.items():
assert torch.allclose(after[key], value), key
def test_not_training():
class AssertTraining(nn.Module):
def forward(self, x):
assert self.training
return x
model = nn.Sequential(AssertTraining())
model.eval()
assert not model.training
sample = torch.rand(1)
balance_by_time(1, model, sample, device="cpu")
assert not model.training
def test_balance_by_time_tuple():
class Twin(nn.Module):
def forward(self, x):
return x, x.detach()
class Add(nn.Module):
def forward(self, a_b):
a, b = a_b
return a + b
model = nn.Sequential(Twin(), Add())
sample = torch.rand(1, requires_grad=True)
balance_by_time(1, model, sample, device="cpu")
@skip_if_no_cuda
def test_balance_by_size_tuple():
class Twin(nn.Module):
def forward(self, x):
return x, x.detach()
class Add(nn.Module):
def forward(self, a_b):
a, b = a_b
return a + b
model = nn.Sequential(Twin(), Add())
sample = torch.rand(1, requires_grad=True)
balance_by_size(1, model, sample)
def test_already_has_grad():
model = nn.Sequential(nn.Conv2d(3, 3, 1))
sample = torch.rand(1, 3, 32, 32)
model(sample).norm().backward()
with pytest.raises(ValueError, match="some parameter already has gradient"):
balance_by_time(1, model, sample, device="cpu")
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import nn
import torch.nn.functional as F
from fairscale.nn.pipe import Pipe
def test_python_autograd_function():
# A Python autograd function might fail with this error:
#
# RuntimeError: Returning Variables sharing storage with other Variables
# that require grad is not supported in Python functions. Please submit a
# feature request if you hit this error.
#
# This doesn't look like an essential restriction, but it happens on the
# current PyTorch version. To avoid it, identity autograd functions such as
# Wait, Fork, and Join should detach the tensor before returning it.
#
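# A hedged sketch of the workaround described above: an identity-style
# autograd function returns a fresh tensor rather than the input itself,
#
#     @staticmethod
#     def forward(ctx, input):
#         return input.detach()  # no storage sharing with a grad-requiring input
#
# The Identity below returns its input unchanged, which is exactly the
# pattern the note warns about.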
class Identity(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input
@staticmethod
def backward(ctx, grad):
return grad
class M(nn.Module):
def forward(self, input):
return Identity.apply(input)
model = nn.Sequential(M(), M())
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
x = torch.rand(42)
y = model(x)
assert torch.allclose(x, y)
def test_exception_no_hang():
# In v0.0.2, once a failed partition received a normal (non-closing) message
# for the next micro-batch, a hang occurred. The reason was that a failed
# partition didn't call in_queue.task_done() on a normal message. So the
# former partition was blocked at out_queue.join() for the next of the next
# micro-batch.
class ExpectedException(Exception):
pass
class Pass(nn.Module):
def forward(self, x):
return x
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3)
with pytest.raises(ExpectedException):
model(torch.rand(3))
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
def test_tuple_wait(cuda_sleep):
# In v0.0.3, Wait is applied only to the first tensor of a micro-batch.
# Under this behavior, if checkpointing is disabled, gradient accumulations
# on the other tensors may not be synchronized properly with the copy
# stream.
class Sleep(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.detach()
@staticmethod
def backward(ctx, grad):
with torch.cuda.device(grad.device):
cuda_sleep(0.05)
return grad
class Layer1(nn.Module):
def forward(self, pair):
a, b = pair
return a * 1, b * 2, b * 3
class Layer2(nn.Module):
def forward(self, triple):
a, b, c = triple
b = Sleep.apply(b)
return a + b + c
model = nn.Sequential(Layer1(), Layer2())
model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never")
a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
y = model((a, b))
y.norm().backward()
torch.cuda.synchronize(0)
torch.cuda.synchronize(1)
assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
def test_parallel_randoms():
class Dropouts(nn.Module):
def forward(self, x):
for _ in range(100):
x = F.dropout(x, p=0.001)
return x
model = nn.Sequential(Dropouts(), Dropouts())
x = torch.rand(10, 10, requires_grad=True)
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always")
y = model(x)
y.norm().backward()
assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import pytest
import torch
from torch import nn
import torch.cuda
from fairscale.nn.pipe.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing
from fairscale.nn.pipe.dependency import fork, join
from fairscale.nn.pipe.microbatch import Batch
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
# Copied from https://github.com/pytorch/pytorch/pull/18568.
timeline = []
class Log(torch.autograd.Function):
@staticmethod
def forward(ctx, name, x):
ctx.name = name
timeline.append(f"{name}:forward")
return x.detach()
@staticmethod
def backward(ctx, grad_output):
name = ctx.name
timeline.append(f"{name}:backward")
return None, grad_output
a = torch.rand(1, device=device, requires_grad=True)
b = torch.rand(1, device=device, requires_grad=True)
# Increase the next function sequence number.
_ = a + 1 + 2 + 3 + 4 + 5
a = checkpoint(partial(Log.apply, "a"), a)
a, phony = fork(a)
b = join(b, phony)
b = checkpoint(partial(Log.apply, "b"), b)
c = torch.cat((a, b))
out = c.sum()
# +--> {a} --Checkpoint(Log)--> {a}
# {out} --Sum--> {c} --Cat ^-----------------------------+
# +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
out.backward()
assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
# |----------------------| |-----------------------| |-----------------------|
# forward pass Checkpoint(Log[b]) Checkpoint(Log[a])
def test_not_requires_grad():
x = Batch(torch.rand(1, requires_grad=False))
assert not x[0].requires_grad
def f(x):
return x * 2
chk = Checkpointing(f, x)
x = chk.checkpoint()
assert x[0].requires_grad
chk.recompute(x)
assert x[0].requires_grad
x.tensor.backward()
def test_not_requires_grad_with_parameter():
x = torch.rand(1, requires_grad=False)
a = torch.rand(1, requires_grad=True)
def f(x):
return x * a
y = checkpoint(f, x)
y.backward()
assert a.grad is not None
@pytest.mark.parametrize("device", devices)
def test_random_in_checkpoint(device):
dropout = nn.Dropout(p=0.5)
torch.manual_seed(0)
x = torch.randn(3, 3, device=device, requires_grad=True)
y = dropout(x)
y.norm().backward()
torch.manual_seed(0)
chk_x = torch.randn(3, 3, device=device, requires_grad=True)
chk_y = checkpoint(dropout, chk_x)
chk_y.norm().backward()
assert torch.allclose(x.grad, chk_x.grad)
def test_detect_checkpointing_recomputing():
logs = []
class Detect(nn.Module):
def forward(self, input):
logs.append((is_checkpointing(), is_recomputing()))
return input
model = Detect()
input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input)
output.backward()
assert logs == [(True, False), (False, True)]
def test_detect_checkpointing_recomputing_without_checkpoint():
logs = []
class Detect(nn.Module):
def forward(self, input):
logs.append((is_checkpointing(), is_recomputing()))
return input
model = Detect()
input = torch.rand(1, requires_grad=True)
output = model(input)
output.backward()
assert logs == [(False, False)]
def test_non_grad_output():
class ForkNonGrad(nn.Module):
def forward(self, input):
return (input * 2, torch.rand(1))
model = ForkNonGrad()
input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input)
output[0].backward()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from fairscale.nn.pipe.copy import Copy, Wait
from fairscale.nn.pipe.stream import CPUStream, current_stream, get_device, is_cuda, new_stream, use_stream
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
device = get_device(prev_stream)
with use_stream(prev_stream):
if is_cuda(prev_stream):
cuda_sleep(0.5)
x = torch.ones(100, device=device, requires_grad=True)
(y,) = Copy.apply(prev_stream, next_stream, x)
(y,) = Wait.apply(prev_stream, next_stream, x)
with use_stream(next_stream):
assert torch.allclose(y.sum(), torch.tensor(100.0, device=device))
y.norm().backward()
with use_stream(prev_stream):
assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device))
def test_copy_wait_cpu_cpu():
prev_stream = CPUStream
next_stream = CPUStream
_test_copy_wait(prev_stream, next_stream)
@skip_if_no_cuda
def test_copy_wait_cpu_cuda(cuda_sleep):
prev_stream = CPUStream
next_stream = current_stream(torch.device("cuda"))
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
@skip_if_no_cuda
def test_copy_wait_cuda_cpu(cuda_sleep):
prev_stream = current_stream(torch.device("cuda"))
next_stream = CPUStream
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
@skip_if_no_cuda
def test_copy_wait_cuda_cuda(cuda_sleep):
prev_stream = current_stream(torch.device("cuda"))
next_stream = new_stream(torch.device("cuda"))
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
def test_wait_multiple_tensors():
a = torch.rand(1, requires_grad=True)
b = torch.rand(1, requires_grad=True)
a, b = Wait.apply(CPUStream, CPUStream, a, b)
assert a.grad_fn is b.grad_fn
assert a.grad_fn.__class__ is Wait._backward_cls
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from itertools import chain
import pytest
import torch
from torch import nn, optim
from fairscale.nn.pipe.batchnorm import DeferredBatchNorm
CHUNKS = 4
def tilt_dist(input):
# Tilt variance by channel.
rgb = input.transpose(0, 1)
rgb[0] *= 1
rgb[1] *= 10
rgb[2] *= 100
# Tilt mean by single batch.
for i, single in enumerate(input):
single += 2 ** i
return input
def chunked_forward(model, input, chunks=CHUNKS):
output_chunks = []
for chunk in input.chunk(chunks):
output_chunks.append(model(chunk))
return torch.cat(output_chunks)
@pytest.mark.parametrize("chunks", [1, 4])
@pytest.mark.parametrize("input_requires_grad", [True, False])
def test_transparency(chunks, input_requires_grad):
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)
input1 = torch.rand(16, 3, 224, 224)
input1 = tilt_dist(input1)
input2 = input1.clone()
input1.requires_grad = input_requires_grad
input2.requires_grad = input_requires_grad
output1 = chunked_forward(bn, input1, chunks=chunks)
output2 = chunked_forward(dbn, input2, chunks=chunks)
assert torch.allclose(output1, output2, atol=1e-4)
output1.mean().backward()
output2.mean().backward()
assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)
if input_requires_grad:
assert input1.grad is not None
assert input2.grad is not None
assert torch.allclose(input1.grad, input2.grad, atol=1e-4)
@pytest.mark.parametrize("momentum", [0.1, None])
def test_running_stats(momentum):
bn = nn.BatchNorm2d(3, momentum=momentum)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
bn(input)
chunked_forward(dbn, input)
assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)
def test_convert_deferred_batch_norm():
bn = nn.BatchNorm2d(3, track_running_stats=False)
bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS)
assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False
dbn = DeferredBatchNorm(3, chunks=CHUNKS)
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS)
assert dbn is dbn_again
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1)
assert dbn is not dbn_again # because of different chunks
def test_eval():
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
bn(input)
chunked_forward(dbn, input)
bn.eval()
dbn.eval()
assert torch.allclose(bn(input), dbn(input), atol=1e-4)
def test_optimize():
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)
for i in range(5):
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
# train
y = bn(input)
a = y.sum()
a.backward()
y = chunked_forward(dbn, input)
b = y.sum()
b.backward()
opt.step()
# eval
bn.eval()
dbn.eval()
with torch.no_grad():
assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10 ** i))
def test_conv_bn():
bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)
# 1st step
a = bn(input)
b = chunked_forward(dbn, input)
# Outputs are different. (per-mini-batch vs. per-micro-batch)
assert not torch.allclose(a, b)
a.sum().backward()
b.sum().backward()
opt.step()
opt.zero_grad()
# Conv layers are also trained differently because of their different outputs.
assert not torch.allclose(bn[0].weight, dbn[0].weight)
# But BNs track identical running stats.
assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)
# 2nd step
a = bn(input)
b = chunked_forward(dbn, input)
a.sum().backward()
b.sum().backward()
# BNs can't track identical running stats due to the different conv layers.
assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)
def test_input_requiring_grad():
dbn = DeferredBatchNorm(3, chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224, requires_grad=True)
input = tilt_dist(input)
chunked_forward(dbn, input)
assert not dbn.sum.requires_grad
assert dbn.sum.grad_fn is None