Commit d8c3aff1 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

tidy

parent 5f0860fc
......@@ -10,7 +10,7 @@ from .inputBatch import InputBatch
from .sparseConvNetTensor import SparseConvNetTensor
from .sparseModule import SparseModule
from .averagePooling import AveragePooling
from .batchNormalization import BatchNormReLU, BatchNormLeakyReLU, BatchNormalizationInTensor
from .batchNormalization import BatchNormalization, BatchNormReLU, BatchNormLeakyReLU, BatchNormalizationInTensor
from .concatTable import ConcatTable
from .convolution import Convolution
from .cAddTable import CAddTable
......@@ -18,7 +18,7 @@ from .deconvolution import Deconvolution
from .denseToSparse import DenseToSparse
from .identity import Identity
from .joinTable import JoinTable
from .leakyReLU import LeakyReLU, Tanh
from .leakyReLU import LeakyReLU
from .maxPooling import MaxPooling
from .networkInNetwork import NetworkInNetwork
from .reLU import ReLU
......@@ -27,3 +27,4 @@ from .sparseToDense import SparseToDense
from .validConvolution import ValidConvolution
from .networkArchitectures import *
from .classificationTrainValidate import ClassificationTrainValidate
from .misc import *
......@@ -21,7 +21,6 @@ from . import SparseModule
from ..utils import toLongTensor, typed_fn, optionalTensor, nullptr
from .sparseConvNetTensor import SparseConvNetTensor
class BatchNormalization(SparseModule):
def __init__(
self,
......@@ -122,7 +121,6 @@ class BatchNormLeakyReLU(BatchNormalization):
',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) + ')'
return s
class BatchNormalizationInTensor(BatchNormalization):
def __init__(
self,
......
......@@ -42,28 +42,3 @@ class LeakyReLU(SparseModule):
if t:
self.output.type(t)
self.gradInput = self.gradInput.type(t)
class Tanh(SparseModule):
def __init__(self):
SparseModule.__init__(self)
self.output = SparseConvNetTensor(torch.Tensor())
#self.gradInput = None if ip else torch.Tensor()
self.gradInput = torch.Tensor()
def updateOutput(self, input):
self.output.metadata = input.metadata
self.output.spatial_size = input.spatial_size
self.output.features=torch.tanh(input.features)
return self.output
def updateGradInput(self, input, gradOutput):
self.gradInput.resize_as_(gradOutput).copy_(gradOutput)
self.gradInput.mul(1+self.output.features)
self.gradInput.mul(1-self.output.features)
return self.gradInput
def type(self, t, tensorCache=None):
if t:
self.output.type(t)
self.gradInput = self.gradInput.type(t)
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch.legacy.nn as nn
from .sequential import Sequential
from .sparseModule import SparseModule
class Tanh(SparseModule):
def __init__(self):
SparseModule.__init__(self)
self.module=nn.Tanh()
self.output = SparseConvNetTensor()
self.output.features=self.module.output
self.gradInput = self.module.gradInput
def updateOutput(self, input):
self.output.metadata = input.metadata
self.output.spatial_size = input.spatial_size
self.module.forward(input.features)
return self.output
def updateGradInput(self, input, gradOutput):
self.module.updateGradInput(input.features,gradOutput)
return self.gradInput
def type(self, t, tensorCache=None):
if t:
self.module.type(t,tensorCache)
self.output.features=self.module.output
self.gradInput = self.module.gradInput
class ELU(SparseModule):
def __init__(self):
SparseModule.__init__(self)
self.module=nn.ELU()
self.output = SparseConvNetTensor()
self.gradInput = self.module.gradInput
def updateOutput(self, input):
self.output.metadata = input.metadata
self.output.spatial_size = input.spatial_size
self.module.forward(input.features)
return self.output
def updateGradInput(self, input, gradOutput):
self.module.updateGradInput(input.features,gradOutput)
return self.gradInput
def type(self, t, tensorCache=None):
if t:
self.module.type(t,tensorCache)
self.output.features=self.module.output
self.gradInput = self.module.gradInput
def BatchNormELU(nPlanes, eps=1e-4, momentum=0.9):
return Sequential().add(BatchNormalization(nPlanes,eps,momentum)).add(ELU())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment