Commit d70dee87 authored by Andrea Paris, committed by Boris Bonev

removed docstrings from _init_weights

parent b17bfdc4
@@ -171,14 +171,6 @@ class DownsamplingBlock(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-
-        Parameters
-        -----------
-        m : torch.nn.Module
-            Module to initialize weights for
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=.02)
             if m.bias is not None:
@@ -344,14 +336,6 @@ class UpsamplingBlock(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-
-        Parameters
-        -----------
-        m : torch.nn.Module
-            Module to initialize weights for
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=.02)
             if m.bias is not None:
@@ -117,14 +117,7 @@ class OverlapPatchMerging(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-        Parameters
-        -----------
-        m : nn.Module
-            Module to initialize
-        """
         if isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
@@ -230,14 +223,7 @@ class MixFFN(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-        Parameters
-        -----------
-        m : nn.Module
-            Module to initialize
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
@@ -792,14 +778,7 @@ class SphericalSegformer(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-        Parameters
-        -----------
-        m : nn.Module
-            Module to initialize
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
@@ -194,14 +194,7 @@ class DownsamplingBlock(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-        Parameters
-        -----------
-        m : nn.Module
-            Module to initialize
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
@@ -585,14 +578,7 @@ class SphericalUNet(nn.Module):
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-        Parameters
-        -----------
-        m : nn.Module
-            Module to initialize
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
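For reference, here is a minimal self-contained sketch of the pattern every hunk above touches: each module registers _init_weights via self.apply(...), draws Conv2d weights from a truncated normal with std 0.02, and sets LayerNorm affine parameters to constants. The class name, layer layout, and the zero-bias branch below are illustrative assumptions, since the diff truncates right after "if m.bias is not None:".

import torch.nn as nn

class InitSketch(nn.Module):
    """Hypothetical stand-in for the blocks touched by this commit."""

    def __init__(self, in_ch: int = 3, out_ch: int = 16):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.LayerNorm(out_ch)
        # apply() visits self and every submodule, so one initializer
        # covers all Conv2d / LayerNorm layers in the block.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            # Truncated normal with std=0.02, matching the hunks above.
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                # Assumption: the line cut off in the diff zeroes the bias.
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            # Matches the OverlapPatchMerging hunk: zero bias, unit scale.
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

# Usage: constructing the module runs the initializer once.
block = InitSketch()
print(block.conv.weight.std())  # roughly 0.02

Because apply() dispatches on module type, deleting the per-class docstrings (as this commit does) changes no behavior; the same isinstance branches still run on every submodule.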