# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch import nn

from fairscale.nn.pipe.skip import Namespace, skippable, verify_skippables


def test_matching():
    """A stash of 'foo' followed by a pop of 'foo' verifies without error."""

    @skippable(stash=["foo"])
    class StashFoo(nn.Module):
        pass

    @skippable(pop=["foo"])
    class PopFoo(nn.Module):
        pass

    model = nn.Sequential(StashFoo(), PopFoo())
    verify_skippables(model)


def test_stash_not_pop():
    """A stash that no layer pops is reported as a TypeError."""

    @skippable(stash=["foo"])
    class StashFoo(nn.Module):
        pass

    model = nn.Sequential(StashFoo())

    with pytest.raises(TypeError) as excinfo:
        verify_skippables(model)

    message = str(excinfo.value)
    assert "no module declared 'foo' as poppable but stashed" in message


def test_pop_unknown():
    """A pop of a name that was never stashed is reported as a TypeError."""

    @skippable(pop=["foo"])
    class PopFoo(nn.Module):
        pass

    model = nn.Sequential(PopFoo())

    with pytest.raises(TypeError) as excinfo:
        verify_skippables(model)

    message = str(excinfo.value)
    assert "'0' declared 'foo' as poppable but it was not stashed" in message


def test_stash_again():
    """A second stash of 'foo' (layer '1') is reported as a redeclaration."""

    @skippable(stash=["foo"])
    class FirstStash(nn.Module):
        pass

    @skippable(stash=["foo"])
    class SecondStash(nn.Module):
        pass

    @skippable(pop=["foo"])
    class PopFoo(nn.Module):
        pass

    model = nn.Sequential(FirstStash(), SecondStash(), PopFoo())

    with pytest.raises(TypeError) as excinfo:
        verify_skippables(model)

    message = str(excinfo.value)
    assert "'1' redeclared 'foo' as stashable" in message


def test_pop_again():
    """A second pop of 'foo' (layer '2') is reported as a redeclaration."""

    @skippable(stash=["foo"])
    class StashFoo(nn.Module):
        pass

    @skippable(pop=["foo"])
    class FirstPop(nn.Module):
        pass

    @skippable(pop=["foo"])
    class SecondPop(nn.Module):
        pass

    model = nn.Sequential(StashFoo(), FirstPop(), SecondPop())

    with pytest.raises(TypeError) as excinfo:
        verify_skippables(model)

    message = str(excinfo.value)
    assert "'2' redeclared 'foo' as poppable" in message


def test_stash_pop_together_different_names():
    """One layer may pop 'foo' and stash 'bar' at the same time."""

    @skippable(stash=["foo"])
    class StashFoo(nn.Module):
        pass

    @skippable(pop=["foo"], stash=["bar"])
    class PopFooStashBar(nn.Module):
        pass

    @skippable(pop=["bar"])
    class PopBar(nn.Module):
        pass

    model = nn.Sequential(StashFoo(), PopFooStashBar(), PopBar())
    verify_skippables(model)


def test_stash_pop_together_same_name():
    """Declaring 'foo' as both stashable and poppable on one layer fails."""

    @skippable(stash=["foo"], pop=["foo"])
    class StashAndPopFoo(nn.Module):
        pass

    model = nn.Sequential(StashAndPopFoo())

    with pytest.raises(TypeError) as excinfo:
        verify_skippables(model)

    message = str(excinfo.value)
    assert "'0' declared 'foo' both as stashable and as poppable" in message


def test_double_stash_pop():
    """Two stash/pop pairs of 'foo' without isolation report both redeclarations."""

    @skippable(stash=["foo"])
    class FirstStash(nn.Module):
        pass

    @skippable(pop=["foo"])
    class FirstPop(nn.Module):
        pass

    @skippable(stash=["foo"])
    class SecondStash(nn.Module):
        pass

    @skippable(pop=["foo"])
    class SecondPop(nn.Module):
        pass

    model = nn.Sequential(FirstStash(), FirstPop(), SecondStash(), SecondPop())

    with pytest.raises(TypeError) as excinfo:
        verify_skippables(model)

    # Both the repeated stash and the repeated pop must appear in one message.
    message = str(excinfo.value)
    assert "'2' redeclared 'foo' as stashable" in message
    assert "'3' redeclared 'foo' as poppable" in message


def test_double_stash_pop_but_isolated():
    """Two stash/pop pairs of 'foo' verify when each pair has its own Namespace."""

    @skippable(stash=["foo"])
    class FirstStash(nn.Module):
        pass

    @skippable(pop=["foo"])
    class FirstPop(nn.Module):
        pass

    @skippable(stash=["foo"])
    class SecondStash(nn.Module):
        pass

    @skippable(pop=["foo"])
    class SecondPop(nn.Module):
        pass

    first_ns = Namespace()
    second_ns = Namespace()

    # Each stash/pop pair is isolated into a separate namespace.
    model = nn.Sequential(
        FirstStash().isolate(first_ns),
        FirstPop().isolate(first_ns),
        SecondStash().isolate(second_ns),
        SecondPop().isolate(second_ns),
    )
    verify_skippables(model)