Commit ce2e18d7 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

spatial size calc for activations

parent 159e5f9a
...@@ -20,7 +20,8 @@ class Sigmoid(Module): ...@@ -20,7 +20,8 @@ class Sigmoid(Module):
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
return output return output
def input_spatial_size(self, out_size):
return out_size
class LeakyReLU(Module): class LeakyReLU(Module):
def __init__(self,leak=1/3): def __init__(self,leak=1/3):
...@@ -32,6 +33,8 @@ class LeakyReLU(Module): ...@@ -32,6 +33,8 @@ class LeakyReLU(Module):
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
return output return output
def input_spatial_size(self, out_size):
return out_size
class Tanh(Module): class Tanh(Module):
...@@ -41,6 +44,8 @@ class Tanh(Module): ...@@ -41,6 +44,8 @@ class Tanh(Module):
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
return output return output
def input_spatial_size(self, out_size):
return out_size
class ReLU(Module): class ReLU(Module):
...@@ -50,6 +55,8 @@ class ReLU(Module): ...@@ -50,6 +55,8 @@ class ReLU(Module):
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
return output return output
def input_spatial_size(self, out_size):
return out_size
class ELU(Module): class ELU(Module):
...@@ -59,6 +66,8 @@ class ELU(Module): ...@@ -59,6 +66,8 @@ class ELU(Module):
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
return output return output
def input_spatial_size(self, out_size):
return out_size
class SELU(Module): class SELU(Module):
def forward(self, input): def forward(self, input):
...@@ -67,6 +76,8 @@ class SELU(Module): ...@@ -67,6 +76,8 @@ class SELU(Module):
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
return output return output
def input_spatial_size(self, out_size):
return out_size
def BatchNormELU(nPlanes, eps=1e-4, momentum=0.9): def BatchNormELU(nPlanes, eps=1e-4, momentum=0.9):
return sparseconvnet.Sequential().add(BatchNormalization(nPlanes,eps,momentum)).add(ELU()) return sparseconvnet.Sequential().add(BatchNormalization(nPlanes,eps,momentum)).add(ELU())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment