# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from itertools import chain

import pytest
import torch
from torch import nn, optim

from fairscale.nn.pipe.batchnorm import DeferredBatchNorm

CHUNKS = 4


def tilt_dist(input):
    # Tilt variance by channel.
    rgb = input.transpose(0, 1)
    rgb[0] *= 1
    rgb[1] *= 10
    rgb[2] *= 100

    # Tilt mean by single batch.
    for i, single in enumerate(input):
        single += 2**i

    return input


def chunked_forward(model, input, chunks=CHUNKS):
    output_chunks = []

    for chunk in input.chunk(chunks):
        output_chunks.append(model(chunk))

    return torch.cat(output_chunks)


@pytest.mark.parametrize("chunks", [1, 4])
@pytest.mark.parametrize("input_requires_grad", [True, False])
def test_transparency(chunks, input_requires_grad):
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)

    input1 = torch.rand(16, 3, 224, 224)
    input1 = tilt_dist(input1)
    input2 = input1.clone()
    input1.requires_grad = input_requires_grad
    input2.requires_grad = input_requires_grad

    output1 = chunked_forward(bn, input1, chunks=chunks)
    output2 = chunked_forward(dbn, input2, chunks=chunks)

    assert torch.allclose(output1, output2, atol=1e-4)

    output1.mean().backward()
    output2.mean().backward()

    assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)

    if input_requires_grad:
        assert input1.grad is not None
        assert input2.grad is not None
        assert torch.allclose(input1.grad, input2.grad, atol=1e-4)


@pytest.mark.parametrize("momentum", [0.1, None])
def test_running_stats(momentum):
    bn = nn.BatchNorm2d(3, momentum=momentum)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)

    bn(input)
    chunked_forward(dbn, input)

    assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
    assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)


def test_convert_deferred_batch_norm():
    bn = nn.BatchNorm2d(3, track_running_stats=False)
    bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS)
    assert type(bn) is nn.BatchNorm2d  # because of track_running_stats=False

    dbn = DeferredBatchNorm(3, chunks=CHUNKS)
    dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS)
    assert dbn is dbn_again

    dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1)
    assert dbn is not dbn_again  # because of different chunks


def test_eval():
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)

    bn(input)
    chunked_forward(dbn, input)

    bn.eval()
    dbn.eval()

    assert torch.allclose(bn(input), dbn(input), atol=1e-4)


def test_optimize():
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)

    for i in range(5):
        input = torch.rand(16, 3, 224, 224)
        input = tilt_dist(input)

        # train
        y = bn(input)
        a = y.sum()
        a.backward()

        y = chunked_forward(dbn, input)
        b = y.sum()
        b.backward()

        opt.step()

        # eval
        bn.eval()
        dbn.eval()
        with torch.no_grad():
            assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i))


def test_conv_bn():
    bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)

    opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)

    # 1st step
    a = bn(input)
    b = chunked_forward(dbn, input)

    # Outputs are different. (per-mini-batch vs. per-micro-batch)
    assert not torch.allclose(a, b)

    a.sum().backward()
    b.sum().backward()
    opt.step()
    opt.zero_grad()

    # Conv layers are also trained differently because of their different outputs.
    assert not torch.allclose(bn[0].weight, dbn[0].weight)

    # But BNs track identical running stats.
    assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
    assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)

    # 2nd step
    a = bn(input)
    b = chunked_forward(dbn, input)
    a.sum().backward()
    b.sum().backward()

    # BNs can't track identical running stats due to the different conv layers.
    assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
    assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)


def test_input_requiring_grad():
    dbn = DeferredBatchNorm(3, chunks=CHUNKS)

    input = torch.rand(16, 3, 224, 224)
    input = tilt_dist(input)
    input.requires_grad = True

    chunked_forward(dbn, input)

    assert not dbn.sum.requires_grad
    assert dbn.sum.grad_fn is None
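

# Illustrative sketch, not a test: a minimal example of the usage pattern these tests
# exercise, assuming a small hypothetical nn.Sequential model (the layer sizes and the
# function name are made up for illustration, not part of the fairscale API). BatchNorm
# layers are converted via convert_deferred_batch_norm, and the mini-batch is then fed
# as exactly CHUNKS micro-batches; each micro-batch is normalized with its own batch
# statistics, while the running stats are deferred and committed once per mini-batch.
def _example_deferred_batch_norm_usage():
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    model = DeferredBatchNorm.convert_deferred_batch_norm(model, chunks=CHUNKS)

    input = torch.rand(16, 3, 32, 32)

    # Feed all CHUNKS micro-batches so the deferred running stats are committed.
    output = torch.cat([model(chunk) for chunk in input.chunk(CHUNKS)])
    return output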