Commit d8c3aff1 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

tidy

parent 5f0860fc
...@@ -10,7 +10,7 @@ from .inputBatch import InputBatch ...@@ -10,7 +10,7 @@ from .inputBatch import InputBatch
from .sparseConvNetTensor import SparseConvNetTensor from .sparseConvNetTensor import SparseConvNetTensor
from .sparseModule import SparseModule from .sparseModule import SparseModule
from .averagePooling import AveragePooling from .averagePooling import AveragePooling
from .batchNormalization import BatchNormReLU, BatchNormLeakyReLU, BatchNormalizationInTensor from .batchNormalization import BatchNormalization, BatchNormReLU, BatchNormLeakyReLU, BatchNormalizationInTensor
from .concatTable import ConcatTable from .concatTable import ConcatTable
from .convolution import Convolution from .convolution import Convolution
from .cAddTable import CAddTable from .cAddTable import CAddTable
...@@ -18,7 +18,7 @@ from .deconvolution import Deconvolution ...@@ -18,7 +18,7 @@ from .deconvolution import Deconvolution
from .denseToSparse import DenseToSparse from .denseToSparse import DenseToSparse
from .identity import Identity from .identity import Identity
from .joinTable import JoinTable from .joinTable import JoinTable
from .leakyReLU import LeakyReLU, Tanh from .leakyReLU import LeakyReLU
from .maxPooling import MaxPooling from .maxPooling import MaxPooling
from .networkInNetwork import NetworkInNetwork from .networkInNetwork import NetworkInNetwork
from .reLU import ReLU from .reLU import ReLU
...@@ -27,3 +27,4 @@ from .sparseToDense import SparseToDense ...@@ -27,3 +27,4 @@ from .sparseToDense import SparseToDense
from .validConvolution import ValidConvolution from .validConvolution import ValidConvolution
from .networkArchitectures import * from .networkArchitectures import *
from .classificationTrainValidate import ClassificationTrainValidate from .classificationTrainValidate import ClassificationTrainValidate
from .misc import *
...@@ -21,7 +21,6 @@ from . import SparseModule ...@@ -21,7 +21,6 @@ from . import SparseModule
from ..utils import toLongTensor, typed_fn, optionalTensor, nullptr from ..utils import toLongTensor, typed_fn, optionalTensor, nullptr
from .sparseConvNetTensor import SparseConvNetTensor from .sparseConvNetTensor import SparseConvNetTensor
class BatchNormalization(SparseModule): class BatchNormalization(SparseModule):
def __init__( def __init__(
self, self,
...@@ -122,7 +121,6 @@ class BatchNormLeakyReLU(BatchNormalization): ...@@ -122,7 +121,6 @@ class BatchNormLeakyReLU(BatchNormalization):
',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) + ')' ',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) + ')'
return s return s
class BatchNormalizationInTensor(BatchNormalization): class BatchNormalizationInTensor(BatchNormalization):
def __init__( def __init__(
self, self,
......
...@@ -42,28 +42,3 @@ class LeakyReLU(SparseModule): ...@@ -42,28 +42,3 @@ class LeakyReLU(SparseModule):
if t: if t:
self.output.type(t) self.output.type(t)
self.gradInput = self.gradInput.type(t) self.gradInput = self.gradInput.type(t)
class Tanh(SparseModule):
def __init__(self):
SparseModule.__init__(self)
self.output = SparseConvNetTensor(torch.Tensor())
#self.gradInput = None if ip else torch.Tensor()
self.gradInput = torch.Tensor()
def updateOutput(self, input):
self.output.metadata = input.metadata
self.output.spatial_size = input.spatial_size
self.output.features=torch.tanh(input.features)
return self.output
def updateGradInput(self, input, gradOutput):
self.gradInput.resize_as_(gradOutput).copy_(gradOutput)
self.gradInput.mul(1+self.output.features)
self.gradInput.mul(1-self.output.features)
return self.gradInput
def type(self, t, tensorCache=None):
if t:
self.output.type(t)
self.gradInput = self.gradInput.type(t)
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch.legacy.nn as nn
from .sequential import Sequential
from .sparseModule import SparseModule
class Tanh(SparseModule):
def __init__(self):
SparseModule.__init__(self)
self.module=nn.Tanh()
self.output = SparseConvNetTensor()
self.output.features=self.module.output
self.gradInput = self.module.gradInput
def updateOutput(self, input):
self.output.metadata = input.metadata
self.output.spatial_size = input.spatial_size
self.module.forward(input.features)
return self.output
def updateGradInput(self, input, gradOutput):
self.module.updateGradInput(input.features,gradOutput)
return self.gradInput
def type(self, t, tensorCache=None):
if t:
self.module.type(t,tensorCache)
self.output.features=self.module.output
self.gradInput = self.module.gradInput
class ELU(SparseModule):
def __init__(self):
SparseModule.__init__(self)
self.module=nn.ELU()
self.output = SparseConvNetTensor()
self.gradInput = self.module.gradInput
def updateOutput(self, input):
self.output.metadata = input.metadata
self.output.spatial_size = input.spatial_size
self.module.forward(input.features)
return self.output
def updateGradInput(self, input, gradOutput):
self.module.updateGradInput(input.features,gradOutput)
return self.gradInput
def type(self, t, tensorCache=None):
if t:
self.module.type(t,tensorCache)
self.output.features=self.module.output
self.gradInput = self.module.gradInput
def BatchNormELU(nPlanes, eps=1e-4, momentum=0.9):
return Sequential().add(BatchNormalization(nPlanes,eps,momentum)).add(ELU())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment