Commit 9f6ebfe7 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

SparseGroupNorm

parent 875a43fd
......@@ -8,7 +8,7 @@ forward_pass_multiplyAdd_count = 0
forward_pass_hidden_states = 0
from .activations import Tanh, Sigmoid, ReLU, LeakyReLU, ELU, SELU, BatchNormELU
from .averagePooling import AveragePooling
from .batchNormalization import BatchNormalization, BatchNormReLU, BatchNormLeakyReLU, MeanOnlyBNLeakyReLU
from .batchNormalization import BatchNormalization, BatchNormReLU, BatchNormLeakyReLU, MeanOnlyBNLeakyReLU, SparseGroupNorm
from .classificationTrainValidate import ClassificationTrainValidate
from .convolution import Convolution
from .deconvolution import Deconvolution
......
......@@ -208,3 +208,11 @@ class MeanOnlyBNLeakyReLU(Module):
def __repr__(self):
s = 'MeanOnlyBatchNorm(' + str(self.nPlanes) + ',momentum=' + str(self.momentum) + ',leakiness=' + str(self.leakiness) + ')'
return s
class SparseGroupNorm(torch.nn.GroupNorm):
def forward(self,x):
return scn.SparseConvNetTensor(
super().forward(x.features),
x.metadata,
x.spatial_size)
......@@ -335,3 +335,10 @@ class VerboseIdentity(torch.nn.Module):
def forward(self, x):
print(x)
return x
class SparseGroupNorm(torch.nn.GroupNorm):
def forward(self,x):
return scn.SparseConvNetTensor(
super().forward(x.features),
x.metadata,
x.spatial_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment