Commit 25985c31 authored by Hang Zhang's avatar Hang Zhang
Browse files

sync BN

parent d40adbc4
...@@ -5,16 +5,14 @@ ...@@ -5,16 +5,14 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Customized Functions"""
import math import math
import threading
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable from torch.autograd import Function, Variable
from torch.nn.modules.utils import _single, _pair, _triple from torch.nn.modules.utils import _pair
from .._ext import encoding_lib from .._ext import encoding_lib
...@@ -23,32 +21,31 @@ __all__ = ['dilatedavgpool2d'] ...@@ -23,32 +21,31 @@ __all__ = ['dilatedavgpool2d']
class _dilatedavgpool2d(Function): class _dilatedavgpool2d(Function):
@staticmethod @staticmethod
def forward(ctx, input, kernel_size, stride, padding, def forward(ctx, input, kernel_size, stride, padding,
dilation=1): dilation=1):
ctx.kH, ctx.kW = _pair(kernel_size) ctx.kH, ctx.kW = _pair(kernel_size)
ctx.dH, ctx.dW = _pair(stride if stride is not None else ctx.dH, ctx.dW = _pair(stride if stride is not None else kernel_size)
kernel_size)
ctx.padH, ctx.padW = _pair(padding) ctx.padH, ctx.padW = _pair(padding)
ctx.dilationH, ctx.dilationW = _pair(dilation) ctx.dilationH, ctx.dilationW = _pair(dilation)
b,c,h,w = input.size() b, c, h, w = input.size()
if ctx.dH==1 and ctx.dW==1: if ctx.dH == 1 and ctx.dW == 1:
# keep the size for dilated avgpool # keep the size for dilated avgpool
ow, oh = w, h ow, oh = w, h
else: else:
ow = math.floor(float(w-ctx.kW+2*ctx.padW)/float(ctx.dW)) +1 ow = math.floor(float(w-ctx.kW+2*ctx.padW)/float(ctx.dW)) +1
oh = math.floor(float(h-ctx.kH+2*ctx.padH)/float(ctx.dH)) +1 oh = math.floor(float(h-ctx.kH+2*ctx.padH)/float(ctx.dH)) +1
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
output = input.new(b,c,oh,ow) output = input.new(b, c, oh, ow)
ctx.save_for_backward(input) ctx.save_for_backward(input)
if isinstance(input, torch.cuda.FloatTensor): if isinstance(input, torch.cuda.FloatTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Float_DilatedAvgPool2d_Forward(input, output, encoding_lib.Encoding_Float_DilatedAvgPool2d_Forward(
ctx.kH, ctx.kW, ctx.dH, ctx.dW, ctx.padH, ctx.padW, input, output, ctx.kH, ctx.kW, ctx.dH, ctx.dW, ctx.padH,
ctx.dilationH, ctx.dilationW) ctx.padW, ctx.dilationH, ctx.dilationW)
elif isinstance(input, torch.cuda.DoubleTensor): elif isinstance(input, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Double_DilatedAvgPool2d_Forward(input, output, encoding_lib.Encoding_Double_DilatedAvgPool2d_Forward(
ctx.kH, ctx.kW, ctx.dH, ctx.dW, ctx.padH, ctx.padW, input, output, ctx.kH, ctx.kW, ctx.dH, ctx.dW, ctx.padH,
ctx.dilationH, ctx.dilationW) ctx.padW, ctx.dilationH, ctx.dilationW)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
return output return output
...@@ -75,13 +72,14 @@ class _dilatedavgpool2d(Function): ...@@ -75,13 +72,14 @@ class _dilatedavgpool2d(Function):
return gradInput, None, None, None, None return gradInput, None, None, None, None
def dilatedavgpool2d(input, kernel_size, stride=None, padding=0, def dilatedavgpool2d(input, kernel_size, stride=None, padding=0,
dilation=1): dilation=1):
"""Dilated Average Pool 2d, for dilation of DenseNet. """Dilated Average Pool 2d, for dilation of DenseNet.
Reference: Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. CVPR 2018 Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. CVPR 2018
Applies 2D average-pooling operation in kh x kw regions by step size Applies 2D average-pooling operation in kh x kw regions by step size
dh x dw steps. The number of output features is equal to the number of dh x dw steps. The number of output features is equal to the number of
...@@ -99,5 +97,4 @@ def dilatedavgpool2d(input, kernel_size, stride=None, padding=0, ...@@ -99,5 +97,4 @@ def dilatedavgpool2d(input, kernel_size, stride=None, padding=0,
a tuple (padh x padw), Default: 0 a tuple (padh x padw), Default: 0
dilation: the dilation parameter similar to Conv2d dilation: the dilation parameter similar to Conv2d
""" """
return _dilatedavgpool2d.apply(input, kernel_size, stride, padding, return _dilatedavgpool2d.apply(input, kernel_size, stride, padding, dilation)
dilation)
...@@ -5,13 +5,11 @@ ...@@ -5,13 +5,11 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import threading """Functions for Encoding Layer"""
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable from torch.autograd import Function, Variable
from .._ext import encoding_lib from .._ext import encoding_lib
...@@ -19,13 +17,13 @@ __all__ = ['aggregate', 'scaledL2'] ...@@ -19,13 +17,13 @@ __all__ = ['aggregate', 'scaledL2']
class _aggregate(Function): class _aggregate(Function):
@staticmethod @staticmethod
def forward(self, A, X, C): def forward(ctx, A, X, C):
# A \in(BxNxK) R \in(BxNxKxD) => E \in(BxNxD) # A \in(BxNxK) R \in(BxNxKxD) => E \in(BxNxD)
self.save_for_backward(A, X, C) ctx.save_for_backward(A, X, C)
B, N, K = A.size() B, _, K = A.size()
D = X.size(2) D = X.size(2)
with torch.cuda.device_of(A): with torch.cuda.device_of(A):
E = A.new(B,K,D) E = A.new(B, K, D)
if isinstance(A, torch.cuda.FloatTensor): if isinstance(A, torch.cuda.FloatTensor):
with torch.cuda.device_of(A): with torch.cuda.device_of(A):
encoding_lib.Encoding_Float_aggregate_forward(E, A, X, C) encoding_lib.Encoding_Float_aggregate_forward(E, A, X, C)
...@@ -37,19 +35,19 @@ class _aggregate(Function): ...@@ -37,19 +35,19 @@ class _aggregate(Function):
return E return E
@staticmethod @staticmethod
def backward(self, gradE): def backward(ctx, gradE):
A, X, C = self.saved_variables A, X, C = ctx.saved_variables
with torch.cuda.device_of(A): with torch.cuda.device_of(A):
gradA = Variable(A.data.new().resize_as_(A.data)) gradA = Variable(A.data.new().resize_as_(A.data))
gradX = Variable(A.data.new().resize_as_(X.data)) gradX = Variable(A.data.new().resize_as_(X.data))
gradC = Variable(A.data.new().resize_as_(C.data)) gradC = Variable(A.data.new().resize_as_(C.data))
if isinstance(A.data, torch.cuda.FloatTensor): if isinstance(A.data, torch.cuda.FloatTensor):
with torch.cuda.device_of(A.data): with torch.cuda.device_of(A.data):
encoding_lib.Encoding_Float_aggregate_backward(gradA.data, encoding_lib.Encoding_Float_aggregate_backward(gradA.data, \
gradE.data, A.data, X.data, C.data) gradE.data, A.data, X.data, C.data)
elif isinstance(A.data, torch.cuda.DoubleTensor): elif isinstance(A.data, torch.cuda.DoubleTensor):
with torch.cuda.device_of(A.data): with torch.cuda.device_of(A.data):
encoding_lib.Encoding_Double_aggregate_backward(gradA.data, encoding_lib.Encoding_Double_aggregate_backward(gradA.data, \
gradE.data, A.data, X.data, C.data) gradE.data, A.data, X.data, C.data)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
...@@ -59,14 +57,17 @@ class _aggregate(Function): ...@@ -59,14 +57,17 @@ class _aggregate(Function):
def aggregate(A, X, C): def aggregate(A, X, C):
r""" r"""
Aggregate operation, aggregate the residuals of inputs (:math:`X`) with repect to the codewords (:math:`C`) with assignment weights (:math:`A`). Aggregate operation, aggregate the residuals of inputs (:math:`X`) with repect
to the codewords (:math:`C`) with assignment weights (:math:`A`).
.. math:: .. math::
e_{k} = \sum_{i=1}^{N} a_{ik} (x_i - d_k) e_{k} = \sum_{i=1}^{N} a_{ik} (x_i - d_k)
Shape: Shape:
- Input: :math:`A\in\mathcal{R}^{B\times N\times K}` :math:`X\in\mathcal{R}^{B\times N\times D}` :math:`C\in\mathcal{R}^{K\times D}` (where :math:`B` is batch, :math:`N` is total number of features, :math:`K` is number is codewords, :math:`D` is feature dimensions.) - Input: :math:`A\in\mathcal{R}^{B\times N\times K}`
:math:`X\in\mathcal{R}^{B\times N\times D}` :math:`C\in\mathcal{R}^{K\times D}`
(where :math:`B` is batch, :math:`N` is total number of features,
:math:`K` is number is codewords, :math:`D` is feature dimensions.)
- Output: :math:`E\in\mathcal{R}^{B\times K\times D}` - Output: :math:`E\in\mathcal{R}^{B\times K\times D}`
Examples: Examples:
...@@ -82,11 +83,11 @@ def aggregate(A, X, C): ...@@ -82,11 +83,11 @@ def aggregate(A, X, C):
class _scaledL2(Function): class _scaledL2(Function):
@staticmethod @staticmethod
def forward(self, X, C, S): def forward(ctx, X, C, S):
B,N,D = X.size() B, N, _ = X.size()
K = C.size(0) K = C.size(0)
with torch.cuda.device_of(X): with torch.cuda.device_of(X):
SL = X.new(B,N,K) SL = X.new(B, N, K)
if isinstance(X, torch.cuda.FloatTensor): if isinstance(X, torch.cuda.FloatTensor):
with torch.cuda.device_of(X): with torch.cuda.device_of(X):
encoding_lib.Encoding_Float_scaledl2_forward(SL, X, C, S) encoding_lib.Encoding_Float_scaledl2_forward(SL, X, C, S)
...@@ -95,12 +96,12 @@ class _scaledL2(Function): ...@@ -95,12 +96,12 @@ class _scaledL2(Function):
encoding_lib.Encoding_Double_scaledl2_forward(SL, X, C, S) encoding_lib.Encoding_Double_scaledl2_forward(SL, X, C, S)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
self.save_for_backward(X, C, S, SL) ctx.save_for_backward(X, C, S, SL)
return SL return SL
@staticmethod @staticmethod
def backward(self, gradSL): def backward(ctx, gradSL):
X, C, S, SL = self.saved_variables X, C, S, SL = ctx.saved_variables
K = C.size(0) K = C.size(0)
with torch.cuda.device_of(X.data): with torch.cuda.device_of(X.data):
gradX = Variable(X.data.new().resize_as_(X.data)) gradX = Variable(X.data.new().resize_as_(X.data))
...@@ -108,15 +109,15 @@ class _scaledL2(Function): ...@@ -108,15 +109,15 @@ class _scaledL2(Function):
gradS = Variable(X.data.new().resize_as_(S.data)) gradS = Variable(X.data.new().resize_as_(S.data))
if isinstance(X.data, torch.cuda.FloatTensor): if isinstance(X.data, torch.cuda.FloatTensor):
with torch.cuda.device_of(X.data): with torch.cuda.device_of(X.data):
encoding_lib.Encoding_Float_scaledl2_backward(gradSL.data, encoding_lib.Encoding_Float_scaledl2_backward(gradSL.data, \
gradX.data, gradC.data, X.data, C.data, S.data) gradX.data, gradC.data, X.data, C.data, S.data)
elif isinstance(X.data, torch.cuda.DoubleTensor): elif isinstance(X.data, torch.cuda.DoubleTensor):
with torch.cuda.device_of(X.data): with torch.cuda.device_of(X.data):
encoding_lib.Encoding_Double_scaledl2_backward(gradSL.data, encoding_lib.Encoding_Double_scaledl2_backward(gradSL.data, \
gradX.data, gradC.data, X.data, C.data, S.data) gradX.data, gradC.data, X.data, C.data, S.data)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
gradS.data.copy_((gradSL*(SL/S.view(1,1,K))).sum(0).sum(0).data) gradS.data.copy_((gradSL*(SL/S.view(1, 1, K))).sum(0).sum(0).data)
return gradX, gradC, gradS return gradX, gradC, gradS
...@@ -128,10 +129,11 @@ def scaledL2(X, C, S): ...@@ -128,10 +129,11 @@ def scaledL2(X, C, S):
sl_{ik} = s_k \|x_i-c_k\|^2 sl_{ik} = s_k \|x_i-c_k\|^2
Shape: Shape:
- Input: :math:`X\in\mathcal{R}^{B\times N\times D}` :math:`C\in\mathcal{R}^{K\times D}` :math:`S\in \mathcal{R}^K` (where :math:`B` is batch, :math:`N` is total number of features, :math:`K` is number is codewords, :math:`D` is feature dimensions.) - Input: :math:`X\in\mathcal{R}^{B\times N\times D}`
:math:`C\in\mathcal{R}^{K\times D}` :math:`S\in \mathcal{R}^K`
(where :math:`B` is batch, :math:`N` is total number of features,
:math:`K` is number is codewords, :math:`D` is feature dimensions.)
- Output: :math:`E\in\mathcal{R}^{B\times N\times K}` - Output: :math:`E\in\mathcal{R}^{B\times N\times K}`
""" """
return _scaledL2.apply(X, C, S) return _scaledL2.apply(X, C, S)
...@@ -5,107 +5,107 @@ ...@@ -5,107 +5,107 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import threading """Synchronized Batch Normalization functions"""
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable from torch.autograd import Function, Variable
from .._ext import encoding_lib from .._ext import encoding_lib
__all__ = ['sum_square', 'batchnormtrain', 'batchnormeval'] __all__ = ['sum_square', 'batchnormtrain', 'batchnormeval']
class _sum_square(Function): class _sum_square(Function):
@staticmethod
def forward(ctx, input): def forward(ctx, input):
ctx.save_for_backward(input) ctx.save_for_backward(input)
B,C,H,W = input.size() B, C, _, _ = input.size()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
xsum = input.new().resize_(C).zero_() xsum = input.new().resize_(C).zero_()
xsquare = input.new().resize_(C).zero_() xsquare = input.new().resize_(C).zero_()
if isinstance(input, torch.cuda.FloatTensor): if isinstance(input, torch.cuda.FloatTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Float_sum_square_Forward( encoding_lib.Encoding_Float_sum_square_Forward(
input.view(B,C,-1), xsum, xsquare) input.view(B, C, -1), xsum, xsquare)
elif isinstance(input, torch.cuda.DoubleTensor): elif isinstance(input, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Double_sum_square_Forward( encoding_lib.Encoding_Double_sum_square_Forward(
input.view(B,C,-1), xsum, xsquare) input.view(B, C, -1), xsum, xsquare)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
return xsum, xsquare return xsum, xsquare
@staticmethod
def backward(ctx, gradSum, gradSquare): def backward(ctx, gradSum, gradSquare):
input, = ctx.saved_tensors input, = ctx.saved_variables
B,C,H,W = input.size() B, C, H, W = input.data.size()
with torch.cuda.device_of(input): with torch.cuda.device_of(input.data):
gradInput = input.new().resize_(B,C,H*W).zero_() gradInput = Variable(input.data.new().resize_(B, C, H*W).zero_())
if isinstance(input, torch.cuda.FloatTensor): if isinstance(input.data, torch.cuda.FloatTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input.data):
encoding_lib.Encoding_Float_sum_square_Backward( encoding_lib.Encoding_Float_sum_square_Backward(
gradInput, input.view(B,C,-1), gradSum, gradSquare) gradInput, input.data.view(B, C, -1), gradSum, gradSquare)
elif isinstance(input, torch.cuda.DoubleTensor): elif isinstance(input.data, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input.data):
encoding_lib.Encoding_Double_sum_square_Backward( encoding_lib.Encoding_Double_sum_square_Backward(
gradInput, input.view(B,C,-1), gradSum, gradSquare) gradInput, input.data.view(B, C, -1), gradSum, gradSquare)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
return gradInput.view(B,C,H,W) return gradInput.view(B, C, H, W)
def sum_square(input): def sum_square(input):
r""" r"""
Calculate sum of elements and sum of squares for Batch Normalization. Calculate sum of elements and sum of squares for Batch Normalization.
""" """
return _sum_square()(input) return _sum_square.apply(input)
class _batchnorm(Function): class _batchnorm(Function):
def __init__(ctx, training=False): def __init__(self, training=False):
super(_batchnorm, ctx).__init__() super(_batchnorm, self).__init__()
ctx.training = training self.training = training
def forward(ctx, input, gamma, beta, mean, std): def forward(self, input, gamma, beta, mean, std):
ctx.save_for_backward(input, gamma, beta, mean, std) self.save_for_backward(input, gamma, beta, mean, std)
assert(input.dim()==3) assert(input.dim() == 3)
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
invstd = 1.0 / std invstd = 1.0 / std
output = input.new().resize_as_(input) output = input.new().resize_as_(input)
if isinstance(input, torch.cuda.FloatTensor): if isinstance(input, torch.cuda.FloatTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Float_batchnorm_Forward(output, encoding_lib.Encoding_Float_batchnorm_Forward(output, \
input, mean, invstd, gamma, beta) input, mean, invstd, gamma, beta)
elif isinstance(input, torch.cuda.DoubleTensor): elif isinstance(input, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Double_batchnorm_Forward(output, encoding_lib.Encoding_Double_batchnorm_Forward(output, \
input, mean, invstd, gamma, beta) input, mean, invstd, gamma, beta)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
return output return output
def backward(ctx, gradOutput): def backward(self, gradOutput):
input, gamma, beta, mean, std = ctx.saved_tensors input, gamma, beta, mean, std = self.saved_tensors
invstd = 1.0 / std invstd = 1.0 / std
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
gradInput = gradOutput.new().resize_as_(input).zero_() gradInput = gradOutput.new().resize_as_(input).zero_()
gradGamma = gradOutput.new().resize_as_(gamma).zero_() gradGamma = gradOutput.new().resize_as_(gamma).zero_()
gradBeta = gradOutput.new().resize_as_(beta).zero_() gradBeta = gradOutput.new().resize_as_(beta).zero_()
gradMean = gradOutput.new().resize_as_(mean).zero_() gradMean = gradOutput.new().resize_as_(mean).zero_()
gradStd = gradOutput.new().resize_as_(std).zero_() gradStd = gradOutput.new().resize_as_(std).zero_()
if isinstance(input, torch.cuda.FloatTensor): if isinstance(input, torch.cuda.FloatTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Float_batchnorm_Backward( encoding_lib.Encoding_Float_batchnorm_Backward(
gradOutput, input, gradInput, gradGamma, gradBeta, gradOutput, input, gradInput, gradGamma, gradBeta,
mean, invstd, gamma, beta, gradMean, gradStd, mean, invstd, gamma, beta, gradMean, gradStd,
ctx.training) self.training)
elif isinstance(input, torch.cuda.DoubleTensor): elif isinstance(input, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
encoding_lib.Encoding_Double_batchnorm_Backward( encoding_lib.Encoding_Double_batchnorm_Backward(
gradOutput, input, gradInput, gradGamma, gradBeta, gradOutput, input, gradInput, gradGamma, gradBeta,
mean, invstd, gamma, beta, gradMean, gradStd, mean, invstd, gamma, beta, gradMean, gradStd,
ctx.training) self.training)
else: else:
raise RuntimeError('Unimplemented data type!') raise RuntimeError('Unimplemented data type!')
return gradInput, gradGamma, gradBeta, gradMean, gradStd return gradInput, gradGamma, gradBeta, gradMean, gradStd
......
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding NN Modules"""
from .encoding import * from .encoding import *
from .syncbn import * from .syncbn import *
from .basic import *
from .customize import * from .customize import *
This diff is collapsed.
...@@ -5,19 +5,15 @@ ...@@ -5,19 +5,15 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import math """Encoding Custermized NN Module"""
import torch import torch
from torch.autograd import Variable from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveAvgPool2d
from torch.nn import Module, Parameter
from torch.nn import functional as F from torch.nn import functional as F
from ..parallel import my_data_parallel
from .syncbn import BatchNorm2d from .syncbn import BatchNorm2d
from ..functions import view_each, upsample
from .basic import *
__all__ = ['GramMatrix', 'View', 'Sum', 'Mean', 'Normalize', 'PyramidPooling'] __all__ = ['GramMatrix', 'View', 'Sum', 'Mean', 'Normalize', 'PyramidPooling']
...@@ -48,12 +44,7 @@ class View(Module): ...@@ -48,12 +44,7 @@ class View(Module):
self.size = torch.Size(args) self.size = torch.Size(args)
def forward(self, input): def forward(self, input):
if isinstance(input, Variable): return input.view(self.size)
return input.view(self.size)
elif isinstance(input, tuple) or isinstance(input, list):
return view_each(input, self.size)
else:
raise RuntimeError('unknown input type')
class Sum(Module): class Sum(Module):
...@@ -63,12 +54,7 @@ class Sum(Module): ...@@ -63,12 +54,7 @@ class Sum(Module):
self.keep_dim = keep_dim self.keep_dim = keep_dim
def forward(self, input): def forward(self, input):
if isinstance(input, Variable): return input.sum(self.dim, self.keep_dim)
return input.sum(self.dim, self.keep_dim)
elif isinstance(input, tuple) or isinstance(input, list):
return my_data_parallel(self, input)
else:
raise RuntimeError('unknown input type')
class Mean(Module): class Mean(Module):
...@@ -78,12 +64,7 @@ class Mean(Module): ...@@ -78,12 +64,7 @@ class Mean(Module):
self.keep_dim = keep_dim self.keep_dim = keep_dim
def forward(self, input): def forward(self, input):
if isinstance(input, Variable): return input.mean(self.dim, self.keep_dim)
return input.mean(self.dim, self.keep_dim)
elif isinstance(input, tuple) or isinstance(input, list):
return my_data_parallel(self, input)
else:
raise RuntimeError('unknown input type')
class Normalize(Module): class Normalize(Module):
...@@ -108,20 +89,15 @@ class Normalize(Module): ...@@ -108,20 +89,15 @@ class Normalize(Module):
def __init__(self, p=2, dim=1): def __init__(self, p=2, dim=1):
super(Normalize, self).__init__() super(Normalize, self).__init__()
self.p = p self.p = p
self.dim =dim self.dim = dim
def forward(self, x): def forward(self, x):
if isinstance(x, Variable): return F.normalize(x, self.p, self.dim, eps=1e-10)
return F.normalize(x, self.p, self.dim, eps=1e-10)
elif isinstance(x, tuple) or isinstance(x, list):
return my_data_parallel(self, x)
else:
raise RuntimeError('unknown input type')
class PyramidPooling(Module): class PyramidPooling(Module):
""" """
Reference: Reference:
Zhao, Hengshuang, et al. *"Pyramid scene parsing network."* Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
""" """
def __init__(self, in_channels): def __init__(self, in_channels):
...@@ -146,31 +122,16 @@ class PyramidPooling(Module): ...@@ -146,31 +122,16 @@ class PyramidPooling(Module):
ReLU(True)) ReLU(True))
def _cat_each(self, x, feat1, feat2, feat3, feat4): def _cat_each(self, x, feat1, feat2, feat3, feat4):
assert(len(x)==len(feat1)) assert(len(x) == len(feat1))
z = [] z = []
for i in range(len(x)): for i in range(len(x)):
z.append( torch.cat((x[i], feat1[i], feat2[i], feat3[i], feat4[i]), 1)) z.append(torch.cat((x[i], feat1[i], feat2[i], feat3[i], feat4[i]), 1))
return z return z
def forward(self, x): def forward(self, x):
if isinstance(x, Variable): _, _, h, w = x.size()
_, _, h, w = x.size() feat1 = F.upsample(self.conv1(self.pool1(x)), (h, w), mode='bilinear')
elif isinstance(x, tuple) or isinstance(x, list): feat2 = F.upsample(self.conv2(self.pool2(x)), (h, w), mode='bilinear')
_, _, h, w = x[0].size() feat3 = F.upsample(self.conv3(self.pool3(x)), (h, w), mode='bilinear')
else: feat4 = F.upsample(self.conv4(self.pool4(x)), (h, w), mode='bilinear')
raise RuntimeError('unknown input type') return torch.cat((x, feat1, feat2, feat3, feat4), 1)
feat1 = upsample(self.conv1(self.pool1(x)),(h,w),
mode='bilinear')
feat2 = upsample(self.conv2(self.pool2(x)),(h,w),
mode='bilinear')
feat3 = upsample(self.conv3(self.pool3(x)),(h,w),
mode='bilinear')
feat4 = upsample(self.conv4(self.pool4(x)),(h,w),
mode='bilinear')
if isinstance(x, Variable):
return torch.cat((x, feat1, feat2, feat3, feat4), 1)
elif isinstance(x, tuple) or isinstance(x, list):
return self._cat_each(x, feat1, feat2, feat3, feat4)
else:
raise RuntimeError('unknown input type')
...@@ -5,53 +5,65 @@ ...@@ -5,53 +5,65 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import threading """Encoding Package Core NN Modules."""
import torch import torch
from torch.nn import Module, Parameter from torch.nn import Module, Parameter
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Function, Variable from torch.autograd import Variable
from torch.nn.modules.utils import _single, _pair, _triple from torch.nn.modules.utils import _pair
from .._ext import encoding_lib from ..functions import scaledL2, aggregate, dilatedavgpool2d
from ..functions import scaledL2, aggregate
from ..parallel import my_data_parallel
from ..functions import dilatedavgpool2d
__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'DilatedAvgPool2d', 'UpsampleConv2d'] __all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'DilatedAvgPool2d', 'UpsampleConv2d']
class Encoding(Module): class Encoding(Module):
r""" r"""
Encoding Layer: a learnable residual encoder over 3d or 4d input that Encoding Layer: a learnable residual encoder.
is seen as a mini-batch.
.. image:: _static/img/cvpr17.svg .. image:: _static/img/cvpr17.svg
:width: 50% :width: 50%
:align: center :align: center
.. math:: Encoding Layer accpets 3D or 4D inputs.
It considers an input featuremaps with the shape of :math:`C\times H\times W`
as a set of C-dimentional input features :math:`X=\{x_1, ...x_N\}`, where N is total number
of features given by :math:`H\times W`, which learns an inherent codebook
:math:`D=\{d_1,...d_K\}` and a set of smoothing factor of visual centers
:math:`S=\{s_1,...s_K\}`. Encoding Layer outputs the residuals with soft-assignment weights
:math:`e_k=\sum_{i=1}^Ne_{ik}`, where
e_{ik} = \frac{exp(-s_k\|x_{i}-c_k\|^2)}{\sum_{j=1}^K exp(-s_j\|x_{i}-c_j\|^2)} (x_i - c_k) .. math::
Please see the `example of training Deep TEN <./experiments/texture.html>`_. e_{ik} = \frac{exp(-s_k\|r_{ik}\|^2)}{\sum_{j=1}^K exp(-s_j\|r_{ij}\|^2)} r_{ik}
Reference: and the residuals are given by :math:`r_{ik} = x_i - d_k`. The output encoders are
Hang Zhang, Jia Xue, and Kristin Dana. "Deep TEN: Texture Encoding Network." *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2017* :math:`E=\{e_1,...e_K\}`.
Args: Args:
D: dimention of the features or feature channels D: dimention of the features or feature channels
K: number of codeswords K: number of codeswords
Shape: Shape:
- Input: :math:`X\in\mathcal{R}^{B\times N\times D}` or :math:`\mathcal{R}^{B\times D\times H\times W}` (where :math:`B` is batch, :math:`N` is total number of features or :math:`H\times W`.) - Input: :math:`X\in\mathcal{R}^{B\times N\times D}` or
:math:`\mathcal{R}^{B\times D\times H\times W}` (where :math:`B` is batch,
:math:`N` is total number of features or :math:`H\times W`.)
- Output: :math:`E\in\mathcal{R}^{B\times K\times D}` - Output: :math:`E\in\mathcal{R}^{B\times K\times D}`
Attributes: Attributes:
codewords (Tensor): the learnable codewords of shape (:math:`K\times D`) codewords (Tensor): the learnable codewords of shape (:math:`K\times D`)
scale (Tensor): the learnable scale factor of visual centers scale (Tensor): the learnable scale factor of visual centers
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Hang Zhang, Jia Xue, and Kristin Dana. "Deep TEN: Texture Encoding Network."
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2017*
Examples: Examples:
>>> import encoding >>> import encoding
>>> import torch >>> import torch
...@@ -66,32 +78,26 @@ class Encoding(Module): ...@@ -66,32 +78,26 @@ class Encoding(Module):
super(Encoding, self).__init__() super(Encoding, self).__init__()
# init codewords and smoothing factor # init codewords and smoothing factor
self.D, self.K = D, K self.D, self.K = D, K
self.codewords = Parameter(torch.Tensor(K, D), self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
requires_grad=True) self.scale = Parameter(torch.Tensor(K), requires_grad=True)
self.scale = Parameter(torch.Tensor(K), requires_grad=True)
self.reset_params() self.reset_params()
def reset_params(self): def reset_params(self):
std1 = 1./((self.K*self.D)**(1/2)) std1 = 1./((self.K*self.D)**(1/2))
self.codewords.data.uniform_(-std1, std1) self.codewords.data.uniform_(-std1, std1)
self.scale.data.uniform_(-1, 0) self.scale.data.uniform_(-1, 0)
def forward(self, X): def forward(self, X):
if isinstance(X, tuple) or isinstance(X, list):
# for self-parallel mode, please see encoding.nn
return my_data_parallel(self, X)
elif not isinstance(X, Variable):
raise RuntimeError('unknown input type')
# input X is a 4D tensor # input X is a 4D tensor
assert(X.size(1)==self.D) assert(X.size(1) == self.D)
if X.dim() == 3: if X.dim() == 3:
# BxDxN # BxDxN
B, N, K, D = X.size(0), X.size(2), self.K, self.D B, D = X.size(0), self.D
X = X.transpose(1,2).contiguous() X = X.transpose(1, 2).contiguous()
elif X.dim() == 4: elif X.dim() == 4:
# BxDxHxW # BxDxHxW
B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D B, D = X.size(0), self.D
X = X.view(B,D,-1).transpose(1,2).contiguous() X = X.view(B, D, -1).transpose(1, 2).contiguous()
else: else:
raise RuntimeError('Encoding Layer unknown input dims!') raise RuntimeError('Encoding Layer unknown input dims!')
# assignment weights NxKxD # assignment weights NxKxD
...@@ -106,15 +112,16 @@ class Encoding(Module): ...@@ -106,15 +112,16 @@ class Encoding(Module):
+ str(self.D) + ')' + str(self.D) + ')'
class EncodingDrop(Module): class EncodingDrop(Module):
r"""Dropout regularized Encoding Layer.
"""
def __init__(self, D, K): def __init__(self, D, K):
super(EncodingDrop, self).__init__() super(EncodingDrop, self).__init__()
# init codewords and smoothing factor # init codewords and smoothing factor
self.D, self.K = D, K self.D, self.K = D, K
self.codewords = Parameter(torch.Tensor(K, D), self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
requires_grad=True) self.scale = Parameter(torch.Tensor(K), requires_grad=True)
self.scale = Parameter(torch.Tensor(K), requires_grad=True)
self.reset_params() self.reset_params()
def reset_params(self): def reset_params(self):
std1 = 1./((self.K*self.D)**(1/2)) std1 = 1./((self.K*self.D)**(1/2))
self.codewords.data.uniform_(-std1, std1) self.codewords.data.uniform_(-std1, std1)
...@@ -127,21 +134,16 @@ class EncodingDrop(Module): ...@@ -127,21 +134,16 @@ class EncodingDrop(Module):
self.scale.data.zero_().add_(-0.5) self.scale.data.zero_().add_(-0.5)
def forward(self, X): def forward(self, X):
if isinstance(X, tuple) or isinstance(X, list):
# for self-parallel mode, please see encoding.nn
return my_data_parallel(self, X)
elif not isinstance(X, Variable):
raise RuntimeError('unknown input type')
# input X is a 4D tensor # input X is a 4D tensor
assert(X.size(1)==self.D) assert(X.size(1) == self.D)
if X.dim() == 3: if X.dim() == 3:
# BxDxN # BxDxN
B, N, K, D = X.size(0), X.size(2), self.K, self.D B, D = X.size(0), self.D
X = X.transpose(1,2).contiguous() X = X.transpose(1, 2).contiguous()
elif X.dim() == 4: elif X.dim() == 4:
# BxDxHxW # BxDxHxW
B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D B, D = X.size(0), self.D
X = X.view(B,D,-1).transpose(1,2).contiguous() X = X.view(B, D, -1).transpose(1, 2).contiguous()
else: else:
raise RuntimeError('Encoding Layer unknown input dims!') raise RuntimeError('Encoding Layer unknown input dims!')
self._drop() self._drop()
...@@ -159,25 +161,28 @@ class EncodingDrop(Module): ...@@ -159,25 +161,28 @@ class EncodingDrop(Module):
class Inspiration(Module): class Inspiration(Module):
r""" r"""
Inspiration Layer (CoMatch Layer) enables the multi-style transfer in feed-forward network, which learns to match the target feature statistics during the training. Inspiration Layer (CoMatch Layer) enables the multi-style transfer in feed-forward
This module is differentialble and can be inserted in standard feed-forward network to be learned directly from the loss function without additional supervision. network, which learns to match the target feature statistics during the training.
This module is differentialble and can be inserted in standard feed-forward network
to be learned directly from the loss function without additional supervision.
.. math:: .. math::
Y = \phi^{-1}[\phi(\mathcal{F}^T)W\mathcal{G}] Y = \phi^{-1}[\phi(\mathcal{F}^T)W\mathcal{G}]
Please see the `example of MSG-Net <./experiments/style.html>`_ Please see the `example of MSG-Net <./experiments/style.html>`_
training multi-style generative network for real-time transfer. training multi-style generative network for real-time transfer.
Reference: Reference:
Hang Zhang and Kristin Dana. "Multi-style Generative Network for Real-time Transfer." *arXiv preprint arXiv:1703.06953 (2017)* Hang Zhang and Kristin Dana. "Multi-style Generative Network for Real-time Transfer."
*arXiv preprint arXiv:1703.06953 (2017)*
""" """
def __init__(self, C, B=1): def __init__(self, C, B=1):
super(Inspiration, self).__init__() super(Inspiration, self).__init__()
# B is equal to 1 or input mini_batch # B is equal to 1 or input mini_batch
self.weight = Parameter(torch.Tensor(1,C,C), requires_grad=True) self.weight = Parameter(torch.Tensor(1, C, C), requires_grad=True)
# non-parameter buffer # non-parameter buffer
self.G = Variable(torch.Tensor(B,C,C), requires_grad=True) self.G = Variable(torch.Tensor(B, C, C), requires_grad=True)
self.C = C self.C = C
self.reset_parameters() self.reset_parameters()
...@@ -189,8 +194,9 @@ class Inspiration(Module): ...@@ -189,8 +194,9 @@ class Inspiration(Module):
def forward(self, X): def forward(self, X):
# input X is a 3D feature map # input X is a 3D feature map
self.P = torch.bmm(self.weight.expand_as(self.G),self.G) self.P = torch.bmm(self.weight.expand_as(self.G), self.G)
return torch.bmm(self.P.transpose(1,2).expand(X.size(0), self.C, self.C), X.view(X.size(0),X.size(1),-1)).view_as(X) return torch.bmm(self.P.transpose(1, 2).expand(X.size(0), self.C, self.C),
X.view(X.size(0), X.size(1), -1)).view_as(X)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(' \ return self.__class__.__name__ + '(' \
...@@ -203,18 +209,21 @@ class DilatedAvgPool2d(Module): ...@@ -203,18 +209,21 @@ class DilatedAvgPool2d(Module):
Reference: Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. CVPR 2018 Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Applies a 2D average pooling over an input signal composed of several input planes. Applies a 2D average pooling over an input signal composed of several input planes.
In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
output :math:`(B, C, H_{out}, W_{out})`, :attr:`kernel_size` :math:`(k_H,k_W)`, :attr:`stride` :math:`(s_H,s_W)` :attr:`dilation` :math:`(d_H,d_W)` output :math:`(B, C, H_{out}, W_{out})`, :attr:`kernel_size` :math:`(k_H,k_W)`,
:attr:`stride` :math:`(s_H,s_W)` :attr:`dilation` :math:`(d_H,d_W)`
can be precisely described as: can be precisely described as:
.. math:: .. math::
\begin{array}{ll} \begin{array}{ll}
out(b, c, h, w) = 1 / (k_H \cdot k_W) \cdot out(b, c, h, w) = 1 / (k_H \cdot k_W) \cdot
\sum_{{m}=0}^{k_H-1} \sum_{{n}=0}^{k_W-1} \sum_{{m}=0}^{k_H-1} \sum_{{n}=0}^{k_W-1}
input(b, c, s_H \cdot h + d_H \cdot m, s_W \cdot w + d_W \cdot n) input(b, c, s_H \cdot h + d_H \cdot m, s_W \cdot w + d_W \cdot n)
\end{array} \end{array}
...@@ -222,11 +231,13 @@ class DilatedAvgPool2d(Module): ...@@ -222,11 +231,13 @@ class DilatedAvgPool2d(Module):
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides | If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
for :attr:`padding` number of points for :attr:`padding` number of points
| The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`,
:attr:`dilation` can either be:
- a single ``int`` -- in which case the same value is used for the height and width dimension - a single ``int`` -- in which case the same value is used for the height
- a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and width dimension
and the second `int` for the width dimension - a ``tuple`` of two ints -- in which case, the first `int` is used for
the height dimension, and the second `int` for the width dimension
Args: Args:
kernel_size: the size of the window kernel_size: the size of the window
...@@ -257,13 +268,8 @@ class DilatedAvgPool2d(Module): ...@@ -257,13 +268,8 @@ class DilatedAvgPool2d(Module):
self.dilation = dilation self.dilation = dilation
def forward(self, input): def forward(self, input):
if isinstance(input, Variable): return dilatedavgpool2d(input, self.kernel_size, self.stride,
return dilatedavgpool2d(input, self.kernel_size, self.stride,
self.padding, self.dilation) self.padding, self.dilation)
elif isinstance(input, tuple) or isinstance(input, list):
return my_data_parallel(self, input)
else:
raise RuntimeError('unknown input type')
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + ' (' \ return self.__class__.__name__ + ' (' \
...@@ -275,14 +281,17 @@ class DilatedAvgPool2d(Module): ...@@ -275,14 +281,17 @@ class DilatedAvgPool2d(Module):
class UpsampleConv2d(Module): class UpsampleConv2d(Module):
r""" r"""
To avoid the checkerboard artifacts of standard Fractionally-strided Convolution, we adapt an integer stride convolution but producing a :math:`2\times 2` outputs for each convolutional window. To avoid the checkerboard artifacts of standard Fractionally-strided Convolution,
we adapt an integer stride convolution but producing a :math:`2\times 2` outputs for
each convolutional window.
.. image:: _static/img/upconv.png .. image:: _static/img/upconv.png
:width: 50% :width: 50%
:align: center :align: center
Reference: Reference:
Hang Zhang and Kristin Dana. "Multi-style Generative Network for Real-time Transfer." *arXiv preprint arXiv:1703.06953 (2017)* Hang Zhang and Kristin Dana. "Multi-style Generative Network for Real-time Transfer."
*arXiv preprint arXiv:1703.06953 (2017)*
Args: Args:
in_channels (int): Number of channels in the input image in_channels (int): Number of channels in the input image
...@@ -290,8 +299,10 @@ class UpsampleConv2d(Module): ...@@ -290,8 +299,10 @@ class UpsampleConv2d(Module):
kernel_size (int or tuple): Size of the convolving kernel kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1 stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
output_padding (int or tuple, optional): Zero-padding added to one side of the output. Default: 0 output_padding (int or tuple, optional): Zero-padding added to one side of the output.
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 Default: 0
groups (int, optional): Number of blocked connections from input channels to output
channels. Default: 1
bias (bool, optional): If True, adds a learnable bias to the output. Default: True bias (bool, optional): If True, adds a learnable bias to the output. Default: True
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
scale_factor (int): scaling factor for upsampling convolution. Default: 1 scale_factor (int): scaling factor for upsampling convolution. Default: 1
...@@ -327,7 +338,7 @@ class UpsampleConv2d(Module): ...@@ -327,7 +338,7 @@ class UpsampleConv2d(Module):
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride=1, def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, scale_factor =1, padding=0, dilation=1, groups=1, scale_factor=1,
bias=True): bias=True):
super(UpsampleConv2d, self).__init__() super(UpsampleConv2d, self).__init__()
kernel_size = _pair(kernel_size) kernel_size = _pair(kernel_size)
...@@ -347,11 +358,11 @@ class UpsampleConv2d(Module): ...@@ -347,11 +358,11 @@ class UpsampleConv2d(Module):
self.groups = groups self.groups = groups
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.weight = Parameter(torch.Tensor( self.weight = Parameter(torch.Tensor(
out_channels * scale_factor * scale_factor, out_channels * scale_factor * scale_factor,
in_channels // groups, *kernel_size)) in_channels // groups, *kernel_size))
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_channels * self.bias = Parameter(torch.Tensor(
scale_factor * scale_factor)) out_channels * scale_factor * scale_factor))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
...@@ -366,12 +377,6 @@ class UpsampleConv2d(Module): ...@@ -366,12 +377,6 @@ class UpsampleConv2d(Module):
self.bias.data.uniform_(-stdv, stdv) self.bias.data.uniform_(-stdv, stdv)
def forward(self, input): def forward(self, input):
if isinstance(input, Variable): out = F.conv2d(input, self.weight, self.bias, self.stride,
out = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
self.padding, self.dilation, self.groups) return F.pixel_shuffle(out, self.scale_factor)
return F.pixel_shuffle(out, self.scale_factor)
elif isinstance(input, tuple) or isinstance(input, list):
return my_data_parallel(self, input)
else:
raise RuntimeError('unknown input type')
This diff is collapsed.
...@@ -5,113 +5,42 @@ ...@@ -5,113 +5,42 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Data Parallel"""
import threading import threading
import torch import torch
import torch.cuda.nccl as nccl from torch.autograd import Variable
import torch.cuda.comm as comm
from torch.autograd import Variable, Function
from torch.nn.modules import Module from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import scatter, scatter_kwargs, \ from torch.nn.parallel.scatter_gather import scatter_kwargs
gather
from torch.nn.parallel.replicate import replicate from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
__all__ = ['Reduce', 'AllReduce', 'Broadcast', 'ModelDataParallel', __all__ = ['allreduce', 'ModelDataParallel', 'CriterionDataParallel']
'CriterionDataParallel', 'SelfDataParallel']
def nccl_all_reduce(inputs):
# TODO, figure out why nccl all_reduce doesn't work for gradcheck
input_size = inputs[0].size()
#if nccl.is_available(inputs):
for i, inp in enumerate(inputs):
assert inp.is_cuda, \
"reduce_add expects all inputs to be on GPUs"
if inp.size() != input_size:
got = 'x'.join(str(x) for x in inp.size())
expected = 'x'.join(str(x) for x in input_size)
raise ValueError("input {} has invalid size: got {}, \
but expected {}".format(i, got, expected))
nccl.all_reduce(inputs)
return inputs
def comm_all_reduce(inputs):
# comm backend
result = comm.reduce_add(inputs)
results = []
for i in range(len(inputs)):
results.append(result.clone().cuda(i))
return results
class Reduce(Function):
def forward(ctx, *inputs):
ctx.save_for_backward(*inputs)
if len(inputs) == 1:
return inputs[0]
return comm.reduce_add(inputs)
def backward(ctx, gradOutput):
inputs = tuple(ctx.saved_tensors)
if len(inputs) == 1:
return gradOutput
gradInputs = []
for i in range(len(inputs)):
with torch.cuda.device_of(inputs[i]):
gradInputs.append(gradOutput.cuda())
return tuple(gradInputs)
class AllReduce(Function):
"""Cross GPU all reduce autograd operation for calculate mean and
variance in SyncBN.
"""
def forward(ctx, *inputs):
outputs = comm_all_reduce(list(inputs))
return tuple(outputs)
def backward(ctx, *gradOutputs):
gradInputs = comm_all_reduce(list(gradOutputs))
return tuple(gradInputs)
class Broadcast(Function): def allreduce(*inputs):
"""Multi-GPU broadcast autograd function """Cross GPU all reduce autograd operation for calculate mean and
variance in SyncBN.
""" """
def __init__(self, target_gpus): target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
super(Broadcast, self).__init__() result = ReduceAddCoalesced.apply(target_gpus[0], 1, *inputs)
self.target_gpus = target_gpus outputs = Broadcast.apply(target_gpus, *result)
assert len(outputs) == len(inputs)
def forward(self, *inputs): return outputs
if not all(input.is_cuda for input in inputs):
raise TypeError('Broadcast function not implemented for CPU tensors')
if len(inputs) == 0:
return tuple()
self.num_inputs = len(inputs)
self.input_device = inputs[0].get_device()
outputs = comm.broadcast_coalesced(inputs, self.target_gpus)
return tuple([t for tensors in outputs for t in tensors])
def backward(self, *grad_outputs):
grad_outputs = [grad_outputs[i:i + self.num_inputs]
for i in range(0, len(grad_outputs), self.num_inputs)]
return comm.reduce_add_coalesced(grad_outputs, self.input_device)
class ModelDataParallel(Module): class ModelDataParallel(Module):
"""Implements data parallelism at the module level. """Implements data parallelism at the module level.
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. CVPR 2018
This container parallelizes the application of the given module by This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the splitting the input across the specified devices by chunking in the
batch dimension. batch dimension.
In the forward pass, the module is replicated on each device, In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
pass, gradients from each replica are summed into the original module. Note that the outputs are not gathered, please use compatible
Note that the outputs are not gathered, please use compatible
:class:`encoding.parallel.CriterionDataParallel`. :class:`encoding.parallel.CriterionDataParallel`.
The batch size should be larger than the number of GPUs used. It should The batch size should be larger than the number of GPUs used. It should
...@@ -122,10 +51,15 @@ class ModelDataParallel(Module): ...@@ -122,10 +51,15 @@ class ModelDataParallel(Module):
module: module to be parallelized module: module to be parallelized
device_ids: CUDA devices (default: all devices) device_ids: CUDA devices (default: all devices)
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example:: Example::
>>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2]) >>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var) >>> y = net(x)
""" """
def __init__(self, module, device_ids=None, output_device=None, dim=0): def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(ModelDataParallel, self).__init__() super(ModelDataParallel, self).__init__()
...@@ -140,13 +74,6 @@ class ModelDataParallel(Module): ...@@ -140,13 +74,6 @@ class ModelDataParallel(Module):
self.master_mean, self.master_var = {}, {} self.master_mean, self.master_var = {}, {}
if len(self.device_ids) == 1: if len(self.device_ids) == 1:
self.module.cuda(device_ids[0]) self.module.cuda(device_ids[0])
"""
# TODO FIXME temporal solution for BN
for m in self.module.modules():
classname = m.__class__.__name__
if classname.find('BatchNorm2d') != -1:
m.momentum = 0.9996
"""
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
...@@ -155,7 +82,7 @@ class ModelDataParallel(Module): ...@@ -155,7 +82,7 @@ class ModelDataParallel(Module):
replicas = self.replicate(self.module, \ replicas = self.replicate(self.module, \
self.device_ids[:len(inputs)]) self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs) outputs = self.parallel_apply(replicas, inputs, kwargs)
return outputs return outputs
def replicate(self, module, device_ids): def replicate(self, module, device_ids):
return replicate(module, device_ids) return replicate(module, device_ids)
...@@ -166,18 +93,26 @@ class ModelDataParallel(Module): ...@@ -166,18 +93,26 @@ class ModelDataParallel(Module):
def parallel_apply(self, replicas, inputs, kwargs): def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs) return parallel_apply(replicas, inputs, kwargs)
class CriterionDataParallel(Module): class CriterionDataParallel(Module):
""" """
Calculate loss in multiple-GPUs, which balance the memory usage for Calculate loss in multiple-GPUs, which balance the memory usage for
Semantic Segmentation. Semantic Segmentation.
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. CVPR 2018
The targets are splitted across the specified devices by chunking in The targets are splitted across the specified devices by chunking in
the batch dimension. Please use together with :class:`encoding.parallel.ModelDataParallel`. the batch dimension. Please use together with :class:`encoding.parallel.ModelDataParallel`.
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
>>> criterion = encoding.nn.CriterionDataParallel(criterion, device_ids=[0, 1, 2])
>>> y = net(x)
>>> loss = criterion(y, target)
""" """
def __init__(self, module, device_ids=None, output_device=None, dim=0): def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(CriterionDataParallel, self).__init__() super(CriterionDataParallel, self).__init__()
...@@ -200,7 +135,7 @@ class CriterionDataParallel(Module): ...@@ -200,7 +135,7 @@ class CriterionDataParallel(Module):
return self.module(inputs, *targets[0], **kwargs[0]) return self.module(inputs, *targets[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, targets, kwargs) outputs = self.parallel_apply(replicas, inputs, targets, kwargs)
return self.gather(outputs, self.output_device) return ReduceAddCoalesced.apply(self.output_device, 1, *outputs) / len(outputs)
def replicate(self, module, device_ids): def replicate(self, module, device_ids):
return replicate(module, device_ids) return replicate(module, device_ids)
...@@ -209,64 +144,10 @@ class CriterionDataParallel(Module): ...@@ -209,64 +144,10 @@ class CriterionDataParallel(Module):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def parallel_apply(self, replicas, inputs, targets, kwargs): def parallel_apply(self, replicas, inputs, targets, kwargs):
return criterion_parallel_apply(replicas, inputs, targets, kwargs) return _criterion_parallel_apply(replicas, inputs, targets, kwargs)
def gather(self, outputs, output_device):
return gather(outputs, output_device, dim=self.dim).mean()
class SelfDataParallel(Module):
"""SelfDataParallel, please make sure you understand it before using.
Reference: def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. CVPR 2018
Each module in the network should be in self-parallel mode,
which allows list of inputs from multiple GPUs.
Please see :class:`encoding.nn` for detail, use with cautious
"""
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(SelfDataParallel, self).__init__()
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
self.dim = dim
self.module = module
self.device_ids = device_ids
self.output_device = output_device
self.master_mean, self.master_var = {}, {}
if len(self.device_ids) == 1:
self.module.cuda(device_ids[0])
def forward(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if self.training:
# self parallel mode
outputs = self.module(inputs)
return outputs
else:
# TODO check faster?
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, \
self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return outputs
def replicate(self, module, device_ids):
return replicate(module, device_ids)
def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs)
def scatter(self, inputs, kwargs, device_ids):
outputs = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
return outputs
def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
assert len(modules) == len(inputs) assert len(modules) == len(inputs)
assert len(targets) == len(inputs) assert len(targets) == len(inputs)
if kwargs_tup: if kwargs_tup:
...@@ -281,13 +162,8 @@ def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None): ...@@ -281,13 +162,8 @@ def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
results = {} results = {}
def _worker(i, module, input, target, kwargs, results, lock): def _worker(i, module, input, target, kwargs, results, lock):
var_input = input
while not isinstance(var_input, Variable):
var_input = var_input[0]
var_target = target
while not isinstance(var_target, Variable):
var_target = var_target[0]
try: try:
var_input = _get_a_var(input)
with torch.cuda.device_of(var_input): with torch.cuda.device_of(var_input):
output = module(input, *target, **kwargs) output = module(input, *target, **kwargs)
with lock: with lock:
...@@ -297,9 +173,8 @@ def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None): ...@@ -297,9 +173,8 @@ def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
results[i] = e results[i] = e
threads = [threading.Thread(target=_worker, threads = [threading.Thread(target=_worker,
args=(i, module, input, target, args=(i, module, input, target,
kwargs, results, lock), kwargs, results, lock),)
)
for i, (module, input, target, kwargs) in for i, (module, input, target, kwargs) in
enumerate(zip(modules, inputs, targets, kwargs_tup))] enumerate(zip(modules, inputs, targets, kwargs_tup))]
...@@ -316,77 +191,18 @@ def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None): ...@@ -316,77 +191,18 @@ def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
return outputs return outputs
def get_a_var(obj): def _get_a_var(obj):
if isinstance(obj, Variable): if isinstance(obj, Variable):
return obj return obj
if isinstance(obj, list) or isinstance(obj, tuple): if isinstance(obj, list) or isinstance(obj, tuple):
results = map(get_a_var, obj) results = map(_get_a_var, obj)
for result in results: for result in results:
if isinstance(result, Variable): if isinstance(result, Variable):
return result return result
if isinstance(obj, dict): if isinstance(obj, dict):
results = map(get_a_var, obj.items()) results = map(_get_a_var, obj.items())
for result in results: for result in results:
if isinstance(result, Variable): if isinstance(result, Variable):
return result return result
return None return None
def my_parallel_apply(modules, inputs, kwargs_tup=None):
assert len(modules) == len(inputs)
if kwargs_tup:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
# Fast track
if len(modules) == 1:
return (modules[0](*inputs[0], **kwargs_tup[0]), )
lock = threading.Lock()
results = {}
def _worker(i, module, input, kwargs, results, lock):
var_input = get_a_var(input)
try:
with torch.cuda.device_of(var_input):
output = module(input, **kwargs)
with lock:
results[i] = output
except Exception as e:
with lock:
results[i] = e
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, results, lock),
)
for i, (module, input, kwargs) in
enumerate(zip(modules, inputs, kwargs_tup))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs
def my_data_parallel(module, inputs, device_ids=None, \
dim=0, module_kwargs=None):
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if len(inputs) == 1:
return module(inputs[0])
#print('my data parallel, len(inputs)', len(inputs))
replicas = replicate(module, device_ids[:len(inputs)])
outputs = my_parallel_apply(replicas, inputs, module_kwargs)
return outputs
...@@ -5,22 +5,20 @@ ...@@ -5,22 +5,20 @@
## Copyright (c) 2017 ## Copyright (c) 2017
## ##
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch """Encoding Util Tools"""
import shutil import shutil
import os import os
import sys
import time
import math import math
import tqdm import torch
__all__ = ['get_optimizer', 'LR_Scheduler', 'save_checkpoint', 'progress_bar'] __all__ = ['get_optimizer', 'LR_Scheduler', 'save_checkpoint']
def get_optimizer(args, model, diff_LR=True): def get_optimizer(args, model, diff_LR=True):
""" """
Returns an optimizer for given model, Returns an optimizer for given model,
Args: Args:
args: :attr:`args.lr`, :attr:`args.momentum`, :attr:`args.weight_decay` args: :attr:`args.lr`, :attr:`args.momentum`, :attr:`args.weight_decay`
...@@ -29,17 +27,17 @@ def get_optimizer(args, model, diff_LR=True): ...@@ -29,17 +27,17 @@ def get_optimizer(args, model, diff_LR=True):
if diff_LR and model.pretrained is not None: if diff_LR and model.pretrained is not None:
print('Using different learning rate for pre-trained features') print('Using different learning rate for pre-trained features')
optimizer = torch.optim.SGD([ optimizer = torch.optim.SGD([
{'params': model.pretrained.parameters()}, {'params': model.pretrained.parameters()},
{'params': model.head.parameters(), {'params': model.head.parameters(),
'lr': args.lr*10}, 'lr': args.lr*10},
], ],
lr=args.lr, lr=args.lr,
momentum=args.momentum, momentum=args.momentum,
weight_decay=args.weight_decay) weight_decay=args.weight_decay)
else: else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
momentum=args.momentum, momentum=args.momentum,
weight_decay=args.weight_decay) weight_decay=args.weight_decay)
return optimizer return optimizer
...@@ -53,12 +51,14 @@ class LR_Scheduler(object): ...@@ -53,12 +51,14 @@ class LR_Scheduler(object):
Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
Args: Args:
args: :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, :attr:`args.lr_step` args: :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`),
:attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
:attr:`args.lr_step`
niters: number of iterations per epoch niters: number of iterations per epoch
""" """
def __init__(self, args, niters=0): def __init__(self, args, niters=0):
self.mode = args.lr_scheduler self.mode = args.lr_scheduler
print('Using {} LR Scheduler!'.format(self.mode)) print('Using {} LR Scheduler!'.format(self.mode))
self.lr = args.lr self.lr = args.lr
if self.mode == 'step': if self.mode == 'step':
...@@ -81,8 +81,7 @@ class LR_Scheduler(object): ...@@ -81,8 +81,7 @@ class LR_Scheduler(object):
raise RuntimeError('Unknown LR scheduler!') raise RuntimeError('Unknown LR scheduler!')
if epoch > self.epoch: if epoch > self.epoch:
print('\n=>Epoches %i, learning rate = %.4f, \ print('\n=>Epoches %i, learning rate = %.4f, \
previous best = %.4f' % ( previous best = %.4f' % (epoch, lr, best_pred))
epoch, lr, best_pred))
self.epoch = epoch self.epoch = epoch
self._adjust_learning_rate(optimizer, lr) self._adjust_learning_rate(optimizer, lr)
...@@ -92,7 +91,7 @@ class LR_Scheduler(object): ...@@ -92,7 +91,7 @@ class LR_Scheduler(object):
else: else:
# enlarge the lr at the head # enlarge the lr at the head
optimizer.param_groups[0]['lr'] = lr optimizer.param_groups[0]['lr'] = lr
for i in range(1,len(optimizer.param_groups)): for i in range(1, len(optimizer.param_groups)):
optimizer.param_groups[i]['lr'] = lr * 10 optimizer.param_groups[i]['lr'] = lr * 10
...@@ -106,88 +105,3 @@ def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'): ...@@ -106,88 +105,3 @@ def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename) torch.save(state, filename)
if is_best: if is_best:
shutil.copyfile(filename, directory + 'model_best.pth.tar') shutil.copyfile(filename, directory + 'model_best.pth.tar')
# refer to https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py
try:
    # `stty size` only works when stdout is attached to a real terminal;
    # on pipes, CI runners or Windows it prints nothing and the unpack of
    # the empty split() raises ValueError — fall back to 80 columns then.
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)-1
except ValueError:
    term_width = 79

# length (in characters) of the drawn bar itself
TOTAL_BAR_LENGTH = 36.
# timestamps used by progress_bar to compute per-step and total time
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    """Draw a one-line textual progress bar on stdout.

    Renders ``[===>....]``, per-step and cumulative timing, an optional
    message, and the ``current+1/total`` counter centered under the bar.
    Uses the module-level ``last_time``/``begin_time`` timestamps.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH * current / total)
    rest = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * rest + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    pieces = [' Step: %s' % _format_time(step_time),
              ' | Tot: %s' % _format_time(tot_time)]
    if msg:
        pieces.append(' | ' + msg)
    msg = ''.join(pieces)
    sys.stdout.write(msg)

    # pad to the terminal edge, then backspace to the center of the bar
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2)))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    # stay on the same line until the bar completes
    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
def _format_time(seconds):
days = int(seconds / 3600/24)
seconds = seconds - days*3600*24
hours = int(seconds / 3600)
seconds = seconds - hours*3600
minutes = int(seconds / 60)
seconds = seconds - minutes*60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds*1000)
f = ''
i = 1
if days > 0:
f += str(days) + 'D'
i += 1
if hours > 0 and i <= 2:
f += str(hours) + 'h'
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + 'm'
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + 's'
i += 1
if millis > 0 and i <= 2:
f += str(millis) + 'ms'
i += 1
if f == '':
f = '0ms'
return f
...@@ -23,16 +23,17 @@ class install(setuptools.command.install.install): ...@@ -23,16 +23,17 @@ class install(setuptools.command.install.install):
def run(self): def run(self):
self.create_version_file() self.create_version_file()
setuptools.command.install.install.run(self) setuptools.command.install.install.run(self)
subprocess.check_call("python test/test.py".split()) subprocess.check_call("python tests/unit_test.py".split())
@staticmethod @staticmethod
def create_version_file(): def create_version_file():
global version, cwd global version, cwd
print('-- Building version ' + version) print('-- Building version ' + version)
version_path = os.path.join(cwd, 'encoding', 'version.py') version_path = os.path.join(cwd, 'encoding', 'version.py')
with open(version_path, 'w') as f: with open(version_path, 'w') as f:
f.write('"""This is encoding version file."""\n')
f.write("__version__ = '{}'\n".format(version)) f.write("__version__ = '{}'\n".format(version))
version = '0.2.0' version = '0.3.0'
try: try:
sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=cwd).decode('ascii').strip() cwd=cwd).decode('ascii').strip()
......
#!/usr/bin/env python
# pylint: disable=protected-access, unused-variable, locally-disabled, len-as-condition
"""Lint helper to generate lint summary of source.
Copyright by Contributors
"""
from __future__ import print_function
import argparse
import codecs
import sys
import re
import os
import cpplint
from cpplint import _cpplint_state
from pylint import epylint
CXX_SUFFIX = set(['cc', 'c', 'cpp', 'h', 'cu', 'hpp'])
PYTHON_SUFFIX = set(['py'])
def filepath_enumerate(paths):
    """Enumerate the file paths of all subfiles of the list of paths.

    Plain files are kept verbatim; directories are walked recursively and
    each contained file is appended as a normalized path.
    """
    result = []
    for entry in paths:
        if os.path.isfile(entry):
            result.append(entry)
            continue
        for root, _dirs, files in os.walk(entry):
            result.extend(os.path.normpath(os.path.join(root, fname))
                          for fname in files)
    return result
class LintHelper(object):
    """Run cpplint/pylint over files and record a per-file error summary."""

    @staticmethod
    def _print_summary_map(strm, result_map, ftype):
        """Print the summary of one result map; return the number of failing files."""
        if len(result_map) == 0:
            return 0
        npass = len([x for k, x in result_map.items() if len(x) == 0])
        strm.write('=====%d/%d %s files passed check=====\n' % (npass, len(result_map), ftype))
        for fname, emap in result_map.items():
            if len(emap) == 0:
                continue
            strm.write('%s: %d Errors of %d Categories map=%s\n' % (
                fname, sum(emap.values()), len(emap), str(emap)))
        return len(result_map) - npass

    def __init__(self):
        self.project_name = None
        # per-file results: path -> {error-category: count}
        self.cpp_header_map = {}
        self.cpp_src_map = {}
        self.python_map = {}
        pylint_disable = ['superfluous-parens',
                          'too-many-instance-attributes',
                          'too-few-public-methods']
        # setup pylint
        self.pylint_opts = ['--extension-pkg-whitelist=numpy',
                            '--disable=' + ','.join(pylint_disable)]
        self.pylint_cats = set(['error', 'warning', 'convention', 'refactor'])
        # setup cpp lint (uses cpplint's private module-level state)
        cpplint_args = ['.', '--extensions=' + (','.join(CXX_SUFFIX))]
        _ = cpplint.ParseArguments(cpplint_args)
        # NOTE(review): '-build/include,' carries a trailing comma inside the
        # filter string (yields an empty filter entry) — looks unintended;
        # confirm against cpplint's filter parsing before changing.
        cpplint._SetFilters(','.join(['-build/c++11',
                                      '-build/namespaces',
                                      '-build/include,',
                                      '+build/include_what_you_use',
                                      '+build/include_order']))
        cpplint._SetCountingStyle('toplevel')
        cpplint._line_length = 100

    def process_cpp(self, path, suffix):
        """Lint a single C/C++ file and store its error-category counts."""
        _cpplint_state.ResetErrorCounts()
        cpplint.ProcessFile(str(path), _cpplint_state.verbose_level)
        _cpplint_state.PrintErrorCounts()
        errors = _cpplint_state.errors_by_category.copy()
        if suffix == 'h':
            self.cpp_header_map[str(path)] = errors
        else:
            self.cpp_src_map[str(path)] = errors

    def process_python(self, path):
        """Lint a single python file and tally pylint message categories."""
        (pylint_stdout, pylint_stderr) = epylint.py_run(
            ' '.join([str(path)] + self.pylint_opts), return_std=True)
        emap = {}
        err = pylint_stderr.read()
        if len(err):
            print(err)
        for line in pylint_stdout:
            sys.stderr.write(line)
            # the category name sits after the last ':' and before '('
            key = line.split(':')[-1].split('(')[0].strip()
            if key not in self.pylint_cats:
                continue
            if key not in emap:
                emap[key] = 1
            else:
                emap[key] += 1
        self.python_map[str(path)] = emap

    def print_summary(self, strm):
        """Print the overall lint summary; return the total number of failing files."""
        nerr = 0
        nerr += LintHelper._print_summary_map(strm, self.cpp_header_map, 'cpp-header')
        # fixed typo: was 'cpp-soruce'
        nerr += LintHelper._print_summary_map(strm, self.cpp_src_map, 'cpp-source')
        nerr += LintHelper._print_summary_map(strm, self.python_map, 'python')
        if nerr == 0:
            strm.write('All passed!\n')
        else:
            strm.write('%d files failed lint\n' % nerr)
        return nerr
# singleton helper for lint check; configures cpplint's module state once
# at import time and accumulates results across process()/main() calls
_HELPER = LintHelper()
def get_header_guard_dmlc(filename):
    """Get Header Guard Convention for DMLC Projects.

    For headers in include, directly use the path.
    For headers in src, use project name plus path.

    Examples: with project-name = dmlc
      include/dmlc/timer.h -> DMLC_TIMER_H_
      src/io/libsvm_parser.h -> DMLC_IO_LIBSVM_PARSER_H_
    """
    rel_path = cpplint.FileInfo(filename).RepositoryName()
    prefixes = ['include', 'api', 'wrapper', 'contrib']
    if os.name == 'nt':
        prefixes.append("mshadow")

    src_idx = rel_path.find('src/')
    if src_idx != -1 and _HELPER.project_name is not None:
        # headers under src/ get the project name prepended
        rel_path = _HELPER.project_name + rel_path[src_idx + 3:]
    else:
        inc_idx = rel_path.find("include/")
        if inc_idx != -1:
            rel_path = rel_path[inc_idx + 8:]
        # strip the first matching well-known top-level directory
        for prefix in prefixes:
            marker = prefix + '/'
            if rel_path.startswith(marker):
                rel_path = rel_path[len(marker):]
                break
    return re.sub(r'[-./\s]', '_', rel_path).upper() + '_'
cpplint.GetHeaderGuardCPPVariable = get_header_guard_dmlc
def process(fname, allow_type):
    """Dispatch one file to the matching lint pass based on its suffix."""
    path = str(fname)
    suffix = path.rsplit('.', 1)[-1]
    # skip names containing '#' (editor backups) and unrequested suffixes
    if path.find('#') != -1 or suffix not in allow_type:
        return
    if suffix in CXX_SUFFIX:
        _HELPER.process_cpp(path, suffix)
    if suffix in PYTHON_SUFFIX:
        _HELPER.process_python(path)
def main():
    """Main entry function.

    Usage: lint.py <project> <python|cpp|all> <path>... [--exclude_path ...]
    Walks the given paths, lints every matching file, prints a summary to
    stderr and exits non-zero when any file failed lint.
    """
    parser = argparse.ArgumentParser(description="lint source codes")
    parser.add_argument('project', help='project name')
    parser.add_argument('filetype', choices=['python', 'cpp', 'all'],
                        help='source code type')
    parser.add_argument('path', nargs='+', help='path to traverse')
    parser.add_argument('--exclude_path', nargs='+', default=[],
                        help='exclude this path, and all subfolders if path is a folder')
    parser.add_argument('--pylint-rc', default=None,
                        help='pylint rc file')
    args = parser.parse_args()
    _HELPER.project_name = args.project
    if args.pylint_rc is not None:
        # an explicit rc file replaces the default pylint options entirely
        _HELPER.pylint_opts = ['--rcfile='+args.pylint_rc,]
    file_type = args.filetype
    # build the set of file suffixes to lint from the requested file type
    allow_type = []
    if file_type == 'python' or file_type == 'all':
        allow_type += [x for x in PYTHON_SUFFIX]
    if file_type == 'cpp' or file_type == 'all':
        allow_type += [x for x in CXX_SUFFIX]
    allow_type = set(allow_type)
    # on Python 2, wrap stderr so unicode lint messages are replaced rather
    # than crashing with UnicodeEncodeError (skipped on Windows consoles)
    if sys.version_info.major == 2 and os.name != 'nt':
        sys.stderr = codecs.StreamReaderWriter(sys.stderr,
                                               codecs.getreader('utf8'),
                                               codecs.getwriter('utf8'),
                                               'replace')
    # get excluded files
    excluded_paths = filepath_enumerate(args.exclude_path)
    for path in args.path:
        if os.path.isfile(path):
            normpath = os.path.normpath(path)
            if normpath not in excluded_paths:
                process(path, allow_type)
        else:
            for root, dirs, files in os.walk(path):
                for name in files:
                    file_path = os.path.normpath(os.path.join(root, name))
                    if file_path not in excluded_paths:
                        process(file_path, allow_type)
    nerr = _HELPER.print_summary(sys.stderr)
    # exit status 1 when any file failed lint (for CI integration)
    sys.exit(nerr > 0)
# allow running this module as a standalone lint script
if __name__ == '__main__':
    main()
[MASTER]
# Specify a configuration file.
#rcfile=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=yes
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=8
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=numpy,opencv
# Allow optimization of some AST trees. This will activate a peephole AST
# optimizer, which will apply various small optimizations. For instance, it can
# be used to obtain the result of joining multiple strings with the addition
# operator. Joining a lot of strings can lead to a maximum recursion error in
# Pylint and this flag can prevent that. It has one side effect, the resulting
# AST will be different than the one from reality. This option is deprecated
# and it will be removed in Pylint 2.0.
optimize-ast=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=indexing-exception,old-raise-syntax
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access,superfluous-parens,invalid-name,no-else-return,useless-super-delegation,len-as-condition,invalid-unary-operand-type,line-too-long,arguments-differ,redefined-builtin,wildcard-import,broad-except,consider-using-enumerate
# disable=unicode-builtin,delslice-method,using-cmp-argument,setslice-method,dict-view-method,parameter-unpacking,range-builtin-not-iterating,print-statement,file-builtin,old-raise-syntax,basestring-builtin,execfile-builtin,indexing-exception,import-star-module-level,coerce-method,long-builtin,old-ne-operator,old-division,no-absolute-import,raw_input-builtin,old-octal-literal,oct-method,xrange-builtin,hex-method,unpacking-in-except,nonzero-method,raising-string,intern-builtin,reload-builtin,metaclass-assignment,cmp-method,filter-builtin-not-iterating,apply-builtin,map-builtin-not-iterating,next-method-called,unichr-builtin,buffer-builtin,dict-iter-method,input-builtin,coerce-builtin,getslice-method,useless-suppression,standarderror-builtin,zip-builtin-not-iterating,suppressed-message,cmp-builtin,backtick,long-suffix,reduce-builtin,round-builtin
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors, warning, statement, which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=100
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator
# Maximum number of lines in a module
max-module-lines=1000
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,XXX,TODO
[TYPECHECK]
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,future.builtins
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=i,j,_,a,b,op,x,y,wd,lr,kv,k,v,s,p,h,c,m,n,X,t,g,f
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty
# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Naming hint for module names
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Regular expression matching correct constant names
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Naming hint for constant names
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
# Naming hint for inline iteration names
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
# Regular expression matching correct method names
method-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for method names
method-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Naming hint for class attribute names
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for argument names
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct attribute names
attr-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for attribute names
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct variable names
variable-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for variable names
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
# Naming hint for class names
class-name-hint=[A-Z_][a-zA-Z0-9]+$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[ELIF]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,__new__,setUp
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=optparse
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[DESIGN]
# Maximum number of arguments for function / method
max-args=5
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of statements in function / method body
max-statements=50
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of boolean expressions in a if statement
max-bool-expr=5
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import encoding import encoding
import unittest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable, gradcheck from torch.autograd import Variable, gradcheck
...@@ -51,7 +53,7 @@ def test_encoding(): ...@@ -51,7 +53,7 @@ def test_encoding():
layer = encoding.nn.Encoding(C,K).double().cuda() layer = encoding.nn.Encoding(C,K).double().cuda()
test = gradcheck(layer, input, eps=1e-6, atol=1e-4) test = gradcheck(layer, input, eps=1e-6, atol=1e-4)
print('Testing encoding(): {}'.format(test)) print('Testing encoding(): {}'.format(test))
def test_sum_square(): def test_sum_square():
B,C,H,W = 2,3,4,5 B,C,H,W = 2,3,4,5
...@@ -62,17 +64,15 @@ def test_sum_square(): ...@@ -62,17 +64,15 @@ def test_sum_square():
print('Testing sum_square(): {}'.format(test)) print('Testing sum_square(): {}'.format(test))
def test_dilated_avgpool(): def test_all_reduce():
X = Variable(torch.cuda.FloatTensor(1,3,75,75).uniform_(-0.5,0.5)) ngpu = torch.cuda.device_count()
input = (X,) X = [torch.DoubleTensor(2,4,4).uniform_(-0.5,0.5).cuda(i) for i in range(ngpu)]
layer = encoding.nn.DilatedAvgPool2d(kernel_size=2, stride=1, padding=0, dilation=2) for x in X:
test = gradcheck(layer, input, eps=1e-6, atol=1e-4) x.requires_grad = True
print('Testing dilatedavgpool2d(): {}'.format(test)) Y = encoding.parallel.allreduce(*X)
assert (len(X) == len(Y))
if __name__ == '__main__': if __name__ == '__main__':
test_scaledL2() import nose
test_encoding() nose.runmodule()
test_aggregate()
test_sum_square()
test_dilated_avgpool()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment