Unverified commit 2dd4d556, authored by Ofey Chan and committed by GitHub
Browse files

[NFC] polish colossalai/nn/init.py code style (#1292)

parent 556b9b7e
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
def zeros_(): def zeros_():
"""Return the initializer filling the input Tensor with the scalar zeros""" """Return the initializer filling the input Tensor with the scalar zeros"""
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.zeros_(tensor) return nn.init.zeros_(tensor)
...@@ -15,6 +16,7 @@ def zeros_(): ...@@ -15,6 +16,7 @@ def zeros_():
def ones_(): def ones_():
"""Return the initializer filling the input Tensor with the scalar ones""" """Return the initializer filling the input Tensor with the scalar ones"""
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.ones_(tensor) return nn.init.ones_(tensor)
...@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.): ...@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.):
mean (float): the mean of the normal distribution. Defaults 0.0. mean (float): the mean of the normal distribution. Defaults 0.0.
std (float): the standard deviation of the normal distribution. Defaults 1.0. std (float): the standard deviation of the normal distribution. Defaults 1.0.
""" """
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.normal_(tensor, mean, std) return nn.init.normal_(tensor, mean, std)
...@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = ...@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float =
a (float): the minimum cutoff value. Defaults -2.0. a (float): the minimum cutoff value. Defaults -2.0.
b (float): the maximum cutoff value. Defaults 2.0. b (float): the maximum cutoff value. Defaults 2.0.
""" """
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.trunc_normal_(tensor, mean, std, a, b) return nn.init.trunc_normal_(tensor, mean, std, a, b)
...@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): ...@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
nonlinearity (str, optional): the non-linear function (`nn.functional` name), nonlinearity (str, optional): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
""" """
# adapted from torch.nn.init # adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape: if 0 in tensor.shape:
...@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): ...@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
nonlinearity (str, optional): the non-linear function (`nn.functional` name), nonlinearity (str, optional): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
""" """
# adapted from torch.nn.init # adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape: if 0 in tensor.shape:
...@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1 ...@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1
scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0. scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
gain (float, optional): an optional scaling factor. Defaults 1.0. gain (float, optional): an optional scaling factor. Defaults 1.0.
""" """
# adapted from torch.nn.init # adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.' assert fan_in is not None, 'Fan_in is not provided.'
...@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.): ...@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.):
scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0. scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
gain (float, optional): an optional scaling factor. Defaults 1.0. gain (float, optional): an optional scaling factor. Defaults 1.0.
""" """
# adapted from torch.nn.init # adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.' assert fan_in is not None, 'Fan_in is not provided.'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment