"vscode:/vscode.git/clone" did not exist on "62f8eb48b1fa121d185ba3226f093d8f11cc9183"
Commit 0cd65242 authored by Mandeep Singh Baines's avatar Mandeep Singh Baines
Browse files

Initial commit

parents
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import weakref
import pytest
import torch
from fairscale.nn.pipe.dependency import Fork, Join, fork, join
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_fork_join():
logs = []
class Log(torch.autograd.Function):
@staticmethod
def forward(ctx, number, tensor):
ctx.number = number
return tensor.detach()
@staticmethod
def backward(ctx, grad):
logs.append(ctx.number)
return None, grad
a = torch.rand(1, device="cpu", requires_grad=True)
b = torch.rand(1, device="cuda", requires_grad=True)
a = Log.apply(1, a)
a, phony = fork(a)
b = join(b, phony)  # 'b' now depends on the phony produced by fork(a)
b = Log.apply(2, b)
b = b.to("cpu")
(a + b).backward()
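# join(b, phony) made 'b' depend on the phony output of fork(a), so the
# backward pass must flow through Log(2) and Join before Fork releases the
# gradient toward Log(1). Hence 2 is logged before 1.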
assert logs == [2, 1]
def test_fork_join_enable_grad():
x = torch.rand(1, requires_grad=True)
with torch.enable_grad():
x2, p = fork(x)
assert p.requires_grad
assert x2 is not x
x = x2
assert x.requires_grad
assert p.requires_grad
assert x.grad_fn.__class__ is Fork._backward_cls
assert p.grad_fn.__class__ is Fork._backward_cls
with torch.enable_grad():
x2 = join(x, p)
assert x2 is not x
x = x2
assert x.requires_grad
assert x.grad_fn.__class__ is Join._backward_cls
def test_fork_join_no_grad(monkeypatch):
def do_not_apply(*args):
raise AssertionError("Function.apply called")
monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply)
x = torch.rand(1, requires_grad=True)
with torch.no_grad():
x2, p = fork(x)
assert not p.requires_grad
assert x2 is x
x = x2
with torch.no_grad():
x2 = join(x, p)
assert x2 is x
x = x2
def test_fork_leak():
leak = None
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input
@staticmethod
def backward(ctx, grad):
nonlocal leak
leak = weakref.ref(ctx)
return grad
x = torch.rand(1, requires_grad=True)
x = F.apply(x)
x, phony = fork(x)
x = join(x, phony)
x.backward()
del x, phony
assert leak() is None
def test_join_when_fork_not_requires_grad():
x = torch.rand(2, 1)
a, b = x.chunk(2)
assert not a.requires_grad
a, p = fork(a)
assert not a.requires_grad
assert not p.requires_grad
assert not b.requires_grad
b = join(b, p)
assert not b.requires_grad
def test_join_when_fork_requires_grad():
x = torch.rand(2, 1)
a, b = x.chunk(2)
a.requires_grad_()
assert a.requires_grad
a, p = fork(a)
assert a.requires_grad
assert p.requires_grad
assert not b.requires_grad
b = join(b, p)
assert b.requires_grad
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
def test_inplace_on_requires_grad():
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
x = torch.rand(1)
y = model(x)
message = r"a leaf Variable that requires grad .* used in an in-place operation."
with pytest.raises(RuntimeError, match=message):
y.backward()
@pytest.mark.xfail(strict=True)
def test_inplace_on_not_requires_grad():
# In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
x = torch.rand(1)
y = model(x)
del model
message = r"a leaf Variable that requires grad .* used in an in-place operation."
with pytest.raises(RuntimeError, match=message):
y.backward()
@pytest.mark.xfail(strict=True)
def test_inplace_incorrect_grad():
class M(nn.Module):
def forward(self, foo_bar):
# 'foo' requires grad but 'bar' does not. In-place operation on
# 'bar' won't cause a RuntimeError.
foo, bar = foo_bar
# add_(1) is not idempotent, in contrast to relu_(). If it is
# executed multiple times, it accumulates the difference onto 'bar'
# each time.
bar.add_(1)
# 'bar' is still captured by checkpointing. 'foo' will get
# incorrect grad.
return foo * bar
model = nn.Sequential(M())
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0])
output = model((foo, bar))
del model
output.backward()
# The gradient of 'foo' should be 2, but it is actually 3 because
# bar.add_(1) was executed twice due to checkpointing.
assert foo.grad.item() == 2.0
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.cuda
from fairscale.nn.pipe.microbatch import Batch, check, gather, scatter
def test_batch_atomic():
x = torch.tensor(42)
b = Batch(x)
assert b.atomic
assert b.tensor is x
with pytest.raises(AttributeError):
b.tensors
assert list(b) == [x]
assert len(b) == 1
assert b[0] is x
def test_batch_non_atomic():
x, y = torch.tensor(42), torch.tensor(21)
b = Batch((x, y))
assert not b.atomic
with pytest.raises(AttributeError):
b.tensor
assert b.tensors == (x, y)
assert list(b) == [x, y]
assert len(b) == 2
assert b[0] is x
assert b[1] is y
def test_batch_call():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
def f(x):
return x
assert a.call(f).atomic
assert not b.call(f).atomic
def test_batch_setitem_by_index():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a[0] = torch.tensor(0)
b[0] = torch.tensor(0)
assert a.atomic
assert a[0].item() == 0
assert not b.atomic
assert len(b) == 2
assert b[0].item() == 0
assert b[1].item() == 21
def test_batch_setitem_by_slice():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a[:] = (torch.tensor(0),)
b[:] = (torch.tensor(0),)
assert a.atomic
assert a[0].item() == 0
assert not b.atomic
assert len(b) == 1
assert b[0].item() == 0
def test_check():
check(torch.tensor(42))
check((torch.tensor(4), torch.tensor(2)))
with pytest.raises(TypeError):
check(42)
with pytest.raises(TypeError):
check("str")
with pytest.raises(TypeError):
check((torch.tensor(4), 2))
def test_gather_tensors():
a = torch.zeros(1, 1)
b = torch.zeros(1, 1)
ab = gather([Batch(a), Batch(b)])
assert ab.size() == (2, 1)
def test_gather_tuples():
a = (torch.zeros(1, 1), torch.zeros(2, 2))
b = (torch.zeros(1, 1), torch.zeros(2, 2))
ab = gather([Batch(a), Batch(b)])
assert isinstance(ab, tuple)
assert ab[0].size() == (2, 1)
assert ab[1].size() == (4, 2)
def test_scatter_tensor():
ab = torch.zeros(2, 1)
a, b = scatter(ab, chunks=2)
assert a.tensor.size() == (1, 1)
assert b.tensor.size() == (1, 1)
def test_scatter_tuple():
ab = (torch.zeros(2, 1), torch.zeros(4, 2))
a, b = scatter(ab, chunks=2)
assert a.tensors[0].size() == (1, 1)
assert b.tensors[0].size() == (1, 1)
assert a.tensors[1].size() == (2, 2)
assert b.tensors[1].size() == (2, 2)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fairscale.nn.pipe.phony import get_phony
def test_phony_size():
p = get_phony(torch.device("cpu"), requires_grad=False)
assert p.size() == (0,)
def test_phony_requires_grad():
p1 = get_phony(torch.device("cpu"), requires_grad=True)
p2 = get_phony(torch.device("cpu"), requires_grad=False)
assert p1.requires_grad
assert not p2.requires_grad
def test_cached_phony():
p1 = get_phony(torch.device("cpu"), requires_grad=True)
p2 = get_phony(torch.device("cpu"), requires_grad=True)
assert p1 is p2
p3 = get_phony(torch.device("cpu"), requires_grad=False)
p4 = get_phony(torch.device("cpu"), requires_grad=False)
assert p3 is p4
assert p1 is not p3
def test_phony_in_autograd_function():
class Phonify(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
phony = get_phony(input.device, requires_grad=False)
return phony.detach()
x = torch.rand(1, requires_grad=True)
p1 = Phonify.apply(x)
p2 = get_phony(torch.device("cpu"), requires_grad=True)
assert p1 is not p2
assert p1.grad_fn is not None
assert p2.grad_fn is None
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from copy import deepcopy
import time
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_parameters():
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, balance=[1], devices=["cpu"], chunks=1)
assert list(pipe.parameters()) != []
def test_public_attrs():
class MyString:
def __init__(self, value):
self.value = value
def __str__(self):
return self.value
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always"))
assert pipe.balance == [1]
assert pipe.devices == [torch.device("cpu")]
assert pipe.chunks == 42
assert isinstance(pipe.chunks, int)
assert pipe.checkpoint == "always"
assert isinstance(pipe.checkpoint, str)
@pytest.mark.parametrize("balance", [[2], [1, 1]])
def test_sequential_like(balance):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, balance, devices=["cpu", "cpu"])
assert len(model) == 2
assert list(model) == [a, b]
assert model[0] is a
assert model[1] is b
with pytest.raises(IndexError):
_ = model[2]
assert model[-1] is b
assert model[-2] is a
def test_balance_wrong_length():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
with pytest.raises(ValueError):
Pipe(model, balance=[1])
with pytest.raises(ValueError):
Pipe(model, balance=[3])
def test_balance_less_than_1():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
with pytest.raises(ValueError):
Pipe(model, balance=[0, 2])
with pytest.raises(ValueError):
Pipe(model, balance=[-1, 3])
def test_chunks_less_than_1():
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError):
Pipe(model, balance=[1], devices=["cpu"], chunks=0)
with pytest.raises(ValueError):
Pipe(model, balance=[1], devices=["cpu"], chunks=-1)
def test_too_few_devices():
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
with pytest.raises(IndexError):
# len(balance) > len(devices)
model = Pipe(model, balance=[1, 1, 1, 1], devices=["cpu"])
def test_batch_size_indivisible():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=4)
with pytest.warns(None) as record:
model(torch.rand(7, 1))
# Indivisible batch size is legal.
assert not record
def test_batch_size_small():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=4)
with pytest.warns(None) as record:
model(torch.rand(2, 1))
# Batch size smaller than chunks is legal.
assert not record
def test_checkpoint_mode():
def count_grad_fn(grad_fn, name, visited=set()):
if grad_fn in visited:
return 0
visited.add(grad_fn)
if grad_fn is None:
return 0
if grad_fn.__class__.__name__ == name:
return 1
counter = 0
for next_grad_fn, _ in grad_fn.next_functions:
counter += count_grad_fn(next_grad_fn, name, visited=visited)
return counter
model = nn.Sequential(nn.Linear(1, 1))
input = torch.rand(2, 1)
always = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="always")
except_last = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="except_last")
never = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="never")
always_output = always(input)
except_last_output = except_last(input)
never_output = never(input)
assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2
assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1
assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0
def test_checkpoint_mode_invalid():
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="INVALID_CHECKPOINT")
def test_checkpoint_mode_when_chunks_1():
model = nn.Sequential(nn.Linear(1, 1))
# All checkpoint modes are fine.
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="except_last")
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="always")
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="never")
def test_checkpoint_eval():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
input = torch.rand(2, 1)
def find_grad_fn(grad_fn, name):
if grad_fn is None:
return False
if grad_fn.__class__.__name__ == name:
return True
for next_grad_fn, _ in grad_fn.next_functions:
if find_grad_fn(next_grad_fn, name):
return True
return False
model.train()
train_output = model(input)
assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")
model.eval()
eval_output = model(input)
assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
def test_checkpoint_non_float_input():
class ForkNonFloat(nn.Module):
def forward(self, input):
return (input * 2, torch.tensor([False]))
class JoinNonFloat(nn.Module):
def forward(self, input):
return input[0] * 2
model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always")
input = torch.rand(1, requires_grad=True)
output = model(input)
output.backward()
def test_no_grad():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
input = torch.rand(2, 1)
latent = None
def hook(module, input, output):
_ = module
_ = input
nonlocal latent
latent = output
partition = model.partitions[0]
partition.register_forward_hook(hook)
with torch.no_grad():
model(input)
assert latent.grad_fn is None
def test_exception():
class ExpectedException(Exception):
pass
class Raise(nn.Module):
def forward(self, *_):
raise ExpectedException()
model = nn.Sequential(Raise())
model = Pipe(model, balance=[1], devices=["cpu"], chunks=1)
with pytest.raises(ExpectedException):
model(torch.rand(1))
def test_exception_early_stop_asap():
"""Even the first partitions have finished to process, the partition before
the failed partition should be killed as soon as possible.
"""
class ExpectedException(Exception):
pass
class Pass(nn.Module):
def forward(self, x):
return x
counter = 0
class Counter(nn.Module):
def forward(self, x):
time.sleep(0.1)
nonlocal counter
counter += 1
return x
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
model = Pipe(model, [1, 1, 1, 1], devices=["cpu", "cpu", "cpu", "cpu"], chunks=3)
with pytest.raises(ExpectedException):
model(torch.rand(3))
# If the early stop doesn't work, it would be 3 instead.
assert counter == 2
def test_input_pair():
class Two(nn.Module):
def __init__(self):
super().__init__()
self.fc_a = nn.Linear(1, 1)
self.fc_b = nn.Linear(1, 1)
def forward(self, a_and_b):
a, b = a_and_b
return (self.fc_a(a), self.fc_b(b))
model = nn.Sequential(Two())
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
a = torch.rand(10, 1, requires_grad=True)
b = torch.rand(10, 1, requires_grad=True)
a_out, b_out = model((a, b))
loss = (a_out + b_out).mean()
loss.backward()
assert a.grad is not None
assert b.grad is not None
def test_input_singleton():
class One(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, only_a):
(a,) = only_a
return (self.fc(a),)
model = nn.Sequential(One())
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
a = torch.rand(10, 1, requires_grad=True)
(a_out,) = model((a,))
loss = a_out.mean()
loss.backward()
assert all(p.grad is not None for p in model.parameters())
assert a.grad is not None
def test_input_varargs():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"])
a = torch.rand(1)
b = torch.rand(1)
# TypeError: forward() takes 2 positional arguments but 3 were given
with pytest.raises(TypeError):
model(a, b)
def test_non_tensor():
class NonTensor(nn.Module):
def forward(self, _):
return "hello"
model = nn.Sequential(NonTensor())
model = Pipe(model, balance=[1], devices=["cpu"])
x = torch.rand(1)
# TypeError: expected Tensor as element 0 in argument 0, but got str
with pytest.raises(TypeError):
model(x)
# TypeError: expected Tensor to scatter, but got str
with pytest.raises(TypeError):
model("hello")
def test_non_tensor_tuple():
class NonTensorTuple(nn.Module):
def forward(self, x):
return (x, "hello")
model = nn.Sequential(NonTensorTuple())
model = Pipe(model, balance=[1], devices=["cpu"])
x = torch.rand(1)
# TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
with pytest.raises(TypeError):
model(x)
# TypeError: expected Tensor to scatter, but got str
with pytest.raises(TypeError):
model((x, "hello"))
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
pipe = Pipe(
nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=2, checkpoint=checkpoint, deferred_batch_norm=True
)
x = torch.rand(4, 3, 10, 10)
pipe(x).mean().backward()
bn(x).mean().backward()
assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)
@pytest.mark.parametrize("checkpoint", ["never", "always"])
def test_deferred_batch_norm_params(checkpoint):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
pipe = Pipe(
nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=1, checkpoint=checkpoint, deferred_batch_norm=True
)
x = torch.rand(4, 3, 10, 10)
pipe(x).mean().backward()
bn(x).mean().backward()
assert pipe[0].weight.grad is not None
assert pipe[0].bias.grad is not None
assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4)
assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4)
def test_devices():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
c = nn.Linear(1, 1)
# There are extra two devices.
devices = ["cpu", "cpu", "cpu", "cpu", "cpu"]
model = nn.Sequential(a, b, c)
model = Pipe(model, [1, 1, 1], devices=devices)
cpu = torch.device("cpu")
# Extra devices must be discarded.
assert model.devices == [cpu, cpu, cpu]
def test_partitions():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
assert isinstance(model.partitions, nn.ModuleList)
assert isinstance(model.partitions[0], nn.Sequential)
assert isinstance(model.partitions[1], nn.Sequential)
assert "partitions.0.0.weight" in model.state_dict()
def test_deny_moving():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
# Moving is denied.
with pytest.raises(TypeError):
model.cuda()
with pytest.raises(TypeError):
model.cpu()
with pytest.raises(TypeError):
model.to(torch.device("cuda"))
with pytest.raises(TypeError):
model.to(0)
with pytest.raises(TypeError):
model.to("cuda")
with pytest.raises(TypeError):
model.to(device=0)
with pytest.raises(TypeError):
model.to(torch.rand(1))
with pytest.raises(TypeError):
model.to(tensor=torch.rand(1))
# Casting is allowed.
model.half()
model.to(torch.double)
model.to(dtype=torch.float)
def test_empty_module():
# Empty sequential module is not illegal.
model = nn.Sequential()
model = Pipe(model, [])
assert model(torch.tensor(42)) == torch.tensor(42)
assert model((torch.tensor(42),)) == (torch.tensor(42),)
# But only a tensor or a tuple of tensors is a legal input to Pipe.
with pytest.raises(TypeError):
model(42)
def test_named_children():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
names = set(n for n, _ in model.named_modules())
assert "partitions.0.a" in names
assert "partitions.1.b" in names
# Pipe doesn't support __getattr__ access to its children. Unlike nn.Sequential,
# Pipe needs to keep several methods in its own namespace.
with pytest.raises(AttributeError):
model.a
def test_recommend_auto_balance():
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# balance is required
Pipe(nn.Sequential())
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# module and sum of balance have different lengths (module: 0, sum of balance: 1)
Pipe(nn.Sequential(), [1])
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# module and sum of balance have different lengths (module: 2, sum of balance: 1)
Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
def test_verify_module_non_sequential():
with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"):
Pipe(nn.Module(), [1])
def test_verify_module_duplicate_children():
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(conv, conv)
with pytest.raises(ValueError, match="module with duplicate children is not supported"):
Pipe(model, [1, 1])
@skip_if_no_cuda
def test_verify_module_duplicate_parameters_on_distinct_devices():
class Surrogate(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
Pipe(model, [1, 1], devices=["cpu", "cuda"])
def test_verify_module_duplicate_parameters_on_same_device():
class Surrogate(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
Pipe(model, [1, 1], devices=["cpu", "cpu"])
def test_forward_lockstep():
timeline = []
class DelayedLog(nn.Module):
def __init__(self, j, seconds):
super().__init__()
self.i = 0
self.j = j
self.seconds = seconds
def forward(self, x):
time.sleep(self.seconds)
timeline.append((self.i, self.j))
self.i += 1
return x
model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1))
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=3)
model(torch.rand(3, 1))
# Expected timeline: (Logs are recorded at !)
#
# Partition #0: 0! 1! 2!
# Partition #1: 000! 111! 222!
#
assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fairscale.nn.pipe.pipeline import clock_cycles
def test_clock_cycles():
assert list(clock_cycles(1, 1)) == [[(0, 0)]]
assert list(clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]]
assert list(clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]]
assert list(clock_cycles(3, 3)) == [ # noqa
[(0, 0)],
[(1, 0), (0, 1)],
[(2, 0), (1, 1), (0, 2)],
[(2, 1), (1, 2)],
[(2, 2)],
]
assert list(clock_cycles(4, 2)) == [ # noqa
[(0, 0)],
[(1, 0), (0, 1)],
[(2, 0), (1, 1)],
[(3, 0), (2, 1)],
[(3, 1)],
]
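# Illustration only: a minimal generator (an assumption, not necessarily the
# actual fairscale implementation) that would satisfy the assertions above.
# With m micro-batches and n partitions, clock cycle k schedules micro-batch i
# on partition j for every pair with i + j == k.
def _clock_cycles_sketch(m, n):
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]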
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from fairscale.nn.pipe.stream import (
CPUStream,
current_stream,
default_stream,
get_device,
is_cuda,
new_stream,
record_stream,
use_device,
use_stream,
wait_stream,
)
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
class TestNewStream:
def test_new_stream_cpu(self):
stream = new_stream(torch.device("cpu"))
assert stream is CPUStream
@skip_if_no_cuda
def test_new_stream_cuda(self):
stream = new_stream(torch.device("cuda"))
assert isinstance(stream, torch.cuda.Stream)
assert stream != torch.cuda.default_stream()
class TestCurrentStream:
def test_current_stream_cpu(self):
stream = current_stream(torch.device("cpu"))
assert stream is CPUStream
@skip_if_no_cuda
def test_current_stream_cuda(self):
stream = current_stream(torch.device("cuda"))
assert isinstance(stream, torch.cuda.Stream)
assert stream == torch.cuda.current_stream()
class TestDefaultStream:
def test_default_stream_cpu(self):
stream = default_stream(torch.device("cpu"))
assert stream is CPUStream
@skip_if_no_cuda
def test_default_stream_cuda(self):
stream = default_stream(torch.device("cuda"))
assert isinstance(stream, torch.cuda.Stream)
assert stream == torch.cuda.default_stream()
class TestUseDevice:
def test_use_device_cpu(self):
with use_device(torch.device("cpu")):
pass
@skip_if_no_cuda
def test_use_device_cuda(self):
with use_device(torch.device("cuda")):
pass
class TestUseStream:
def test_use_stream_cpu(self):
with use_stream(CPUStream):
pass
@skip_if_no_cuda
def test_use_stream_cuda(self):
stream = new_stream(torch.device("cuda"))
with use_stream(stream):
assert current_stream(torch.device("cuda")) == stream
class TestGetDevice:
def test_get_device_cpu(self):
assert get_device(CPUStream).type == "cpu"
@skip_if_no_cuda
def test_get_device_cuda(self):
stream = current_stream(torch.device("cuda"))
assert get_device(stream).type == "cuda"
class TestWaitStream:
def _test_wait_stream(self, source, target, cuda_sleep=None):
with use_stream(target):
if is_cuda(target):
cuda_sleep(0.5)
x = torch.ones(100, 100, device=get_device(target))
wait_stream(source, target)
with use_stream(source):
assert x.sum().item() == 10000
def test_wait_stream_cpu_cpu(self):
source = CPUStream
target = CPUStream
self._test_wait_stream(source, target)
@skip_if_no_cuda
def test_wait_stream_cpu_cuda(self, cuda_sleep):
source = CPUStream
target = new_stream(torch.device("cuda"))
self._test_wait_stream(source, target, cuda_sleep)
@skip_if_no_cuda
def test_wait_stream_cuda_cpu(self, cuda_sleep):
source = new_stream(torch.device("cuda"))
target = CPUStream
self._test_wait_stream(source, target, cuda_sleep)
@skip_if_no_cuda
def test_wait_stream_cuda_cuda(self, cuda_sleep):
source = current_stream(torch.device("cuda"))
target = new_stream(torch.device("cuda"))
self._test_wait_stream(source, target, cuda_sleep)
class TestRecordStream:
def test_record_stream_cpu(self):
# It should silently ignore CPU tensors.
x = torch.rand(1, device=torch.device("cpu"))
record_stream(x, CPUStream)
@skip_if_no_cuda
def test_record_stream_cuda(self, cuda_sleep):
# This test detects unexpected block reallocation. For a reliable test,
# the stream used to allocate tensors is isolated: the allocator will not
# reuse free blocks that were allocated on another stream.
stream_alloc = new_stream(torch.device("cuda"))
with torch.cuda.stream(stream_alloc):
x = torch.rand(1, device=torch.device("cuda"))
stream = new_stream(torch.device("cuda"))
record_stream(x, stream)
with use_stream(stream):
cuda_sleep(0.5)
# 'x' is deleted from Python's perspective, but the block backing 'x' is
# still required by 'stream', so 'y' shouldn't be allocated into that block.
data_ptr = x.data_ptr()
del x
stream_alloc.synchronize()
with torch.cuda.stream(stream_alloc):
y = torch.rand(1, device=torch.device("cuda"))
assert y.data_ptr() != data_ptr
# Pause Python until 'stream' finishes its queued tasks. Now the block of
# 'x' is free to be reallocated.
wait_stream(CPUStream, stream)
with torch.cuda.stream(stream_alloc):
z = torch.rand(1, device=torch.device("cuda"))
assert z.data_ptr() == data_ptr
@skip_if_no_cuda
def test_record_stream_shifted_view(self, cuda_sleep):
# Issue: https://github.com/pytorch/pytorch/issues/27366
stream_alloc = new_stream(torch.device("cuda"))
with torch.cuda.stream(stream_alloc):
x = torch.rand(2, device=torch.device("cuda"))
y = x[1:]
assert y.data_ptr() > x.data_ptr()
stream = new_stream(torch.device("cuda"))
with use_stream(stream):
cuda_sleep(0.5)
record_stream(y, stream)
data_ptr = x.data_ptr()
del x, y
stream_alloc.synchronize()
with torch.cuda.stream(stream_alloc):
z = torch.rand(2, device=torch.device("cuda"))
assert z.data_ptr() != data_ptr
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from fairscale.nn import Pipe
def test_simple_linears():
def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None])
def zero_grad(parameters):
for p in parameters:
p.grad = None
inputs = torch.rand(8, 1)
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
# Without Pipe
outputs = model(inputs)
loss = outputs.mean()
loss.backward()
grad_without_pipe = sum_grad(model.parameters())
zero_grad(model.parameters())
# With Pipe
model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4)
outputs = model(inputs)
loss = outputs.mean()
loss.backward()
grad_with_pipe = sum_grad(model.parameters())
# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
import pytest
import torch
from fairscale.nn.pipe.microbatch import Batch
from fairscale.nn.pipe.stream import CPUStream
from fairscale.nn.pipe.worker import Task, spawn_workers
class fake_device:
"""A test double for :class:`torch.device`. Every fake device is different
with each other.
"""
type = "fake"
index = None
def test_join_running_workers():
count = 0
def counter():
nonlocal count
time.sleep(0.1)
count += 1
return Batch(())
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
def call_in_worker(i, f):
task = Task(CPUStream, compute=f, finalize=None)
in_queues[i].put(task)
for i in range(10):
call_in_worker(i, counter)
# There's no nondeterminism because 'spawn_workers' joins all running
# workers.
assert count == 10
def test_join_running_workers_with_exception():
class ExpectedException(Exception):
pass
count = 0
def counter():
nonlocal count
time.sleep(0.1)
count += 1
return Batch(())
with pytest.raises(ExpectedException):
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
def call_in_worker(i, f):
task = Task(CPUStream, compute=f, finalize=None)
in_queues[i].put(task)
for i in range(10):
call_in_worker(i, counter)
raise ExpectedException
# There's no nondeterminism because only one task is placed in each input
# queue.
assert count == 10
def test_compute_multithreading():
"""Task.compute should be executed on multiple threads."""
thread_ids = set()
def log_thread_id():
thread_id = threading.current_thread().ident
thread_ids.add(thread_id)
return Batch(())
with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues):
for i in range(2):
t = Task(CPUStream, compute=log_thread_id, finalize=None)
in_queues[i].put(t)
for i in range(2):
out_queues[i].get()
assert len(thread_ids) == 2
def test_compute_success():
"""Task.compute returns (True, (task, batch)) on success."""
def _42():
return Batch(torch.tensor(42))
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
t = Task(CPUStream, compute=_42, finalize=None)
in_queues[0].put(t)
ok, (task, batch) = out_queues[0].get()
assert ok
assert task is t
assert isinstance(batch, Batch)
assert batch[0].item() == 42
def test_compute_exception():
"""Task.compute returns (False, exc_info) on failure."""
def zero_div():
0 / 0
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
t = Task(CPUStream, compute=zero_div, finalize=None)
in_queues[0].put(t)
ok, exc_info = out_queues[0].get()
assert not ok
assert isinstance(exc_info, tuple)
assert issubclass(exc_info[0], ZeroDivisionError)
@pytest.mark.parametrize("grad_mode", [True, False])
def test_grad_mode(grad_mode):
def detect_grad_enabled():
x = torch.rand(1, requires_grad=torch.is_grad_enabled())
return Batch(x)
with torch.set_grad_enabled(grad_mode):
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
task = Task(CPUStream, compute=detect_grad_enabled, finalize=None)
in_queues[0].put(task)
ok, (_, batch) = out_queues[0].get()
assert ok
assert batch[0].requires_grad == grad_mode
def test_worker_per_device():
cpu = torch.device("cpu")
cpu0 = torch.device("cpu", index=0)
fake1 = fake_device()
fake2 = fake_device()
with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues):
assert len(in_queues) == len(out_queues) == 5
# 0: cpu, 1: cpu, 2: cpu0
assert in_queues[0] is in_queues[1] is in_queues[2]
assert out_queues[0] is out_queues[1] is out_queues[2]
# 3: fake1, 4: fake2
assert in_queues[3] is not in_queues[4]
assert out_queues[3] is not out_queues[4]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import fairscale.optim as optim
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def setup_module(module):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl", rank=0, world_size=1)
def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
def test_create():
params = [torch.rand(1)]
o = optim.OSS(params, lr=0.01)
@skip_if_no_cuda
def test_state_dict():
x = torch.tensor([1.0], device="cuda", requires_grad=True)
o = optim.OSS([x], lr=0.1)
state_dict = o.state_dict()
o = optim.OSS([x], lr=0.01)
o.load_state_dict(state_dict)
# We should now be using a lr of 0.1.
x.backward()
o.step()
assert x == torch.tensor([0.9], device="cuda")
def run_test_add_param_group(rank, world_size):
dist_init(rank, world_size)
params = []
for size in [4, 5, 2, 6, 4]:
params.append(torch.rand(size, 1))
o = optim.OSS(params, lr=0.1)
assert len(o.param_groups) == 1
o.add_param_group({"params": [torch.rand(3, 1)]})
assert len(o.param_groups) == 2
# Verify that the added group went to the correct partition, so every partition holds 8 elements.
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
if rank == 1:
assert len(o.optim.param_groups) == 2
else:
assert len(o.optim.param_groups) == 1
def test_add_param_group():
world_size = 3
mp.spawn(run_test_add_param_group, args=(world_size,), nprocs=world_size, join=True)
def run_test_zero_grad(rank, world_size):
dist_init(rank, world_size)
x = torch.rand(1)
m = torch.nn.Linear(1, 1)
o = optim.OSS(m.parameters(), lr=0.1)
y = m(x)
y.backward(x)
assert m.weight.grad
assert m.bias.grad
o.zero_grad()
assert not m.weight.grad
assert not m.bias.grad
@skip_if_no_cuda
def test_zero_grad():
world_size = 2
mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
def run_test_step(rank, world_size):
dist_init(rank, world_size)
x = torch.tensor([float(rank + 1)], device=rank)
m = torch.nn.Linear(1, 1)
m.weight.data = torch.tensor([[1.0]])
m.bias.data = torch.tensor([2.0])
m.to(rank)
o = optim.OSS(m.parameters(), lr=0.1)
y = m(x)
y.backward(x)
for p in m.parameters():
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= world_size
o.step()
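# Expected update: y.backward(x) gives dL/dw = x * x (1 and 4) and dL/db = x
# (1 and 2) on ranks 0 and 1. After averaging: dw = 2.5, db = 1.5; with
# lr=0.1 the weight becomes 1.0 - 0.25 = 0.75 and the bias 2.0 - 0.15 = 1.85.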
assert m.weight == torch.tensor([[0.75]], device=rank)
assert m.bias == torch.tensor([1.85], device=rank)
@skip_if_no_cuda
def test_step():
world_size = 2
mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
def run_test_step_with_closure(rank, world_size):
dist_init(rank, world_size)
x_val = rank + 1
weight = 1.0
bias = 2.0
error = 1.0
target = torch.tensor([x_val * weight + bias + error], device=rank)
loss_fn = torch.nn.L1Loss()
x = torch.tensor([float(x_val)], device=rank)
m = torch.nn.Linear(1, 1)
m.weight.data = torch.tensor([[weight]])
m.bias.data = torch.tensor([bias])
m.to(rank)
o = optim.OSS(m.parameters(), lr=0.1)
y = m(x)
y.backward(x)
for p in m.parameters():
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= world_size
def closure():
o.zero_grad()
output = m(x)
loss = loss_fn(output, target)
loss.backward()
return loss
loss = o.step(closure=closure)
assert loss == torch.tensor(error, device=rank)
assert m.weight == torch.tensor([[1.1]], device=rank)
assert m.bias == torch.tensor([2.1], device=rank)
@skip_if_no_cuda
def test_step_with_closure():
world_size = 2
mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
def run_test_sharding(rank, world_size):
dist_init(rank, world_size)
params = []
for size in [5, 4, 2, 6, 4, 3]:
params.append(torch.rand(size, 1))
o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
def test_sharding():
world_size = 3
mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)