"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "4d582893a79b72a878e8fac52b3282799e404636"
Unverified commit 2dd4d556, authored by Ofey Chan, committed by GitHub

[NFC] polish colossalai/nn/init.py code style (#1292)

parent 556b9b7e
@@ -7,6 +7,7 @@ import torch.nn as nn
def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros"""
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.zeros_(tensor)
@@ -15,6 +16,7 @@ def zeros_():
def ones_():
    """Return the initializer filling the input Tensor with the scalar ones"""
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.ones_(tensor)
@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.):
        mean (float): the mean of the normal distribution. Defaults 0.0.
        std (float): the standard deviation of the normal distribution. Defaults 1.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.normal_(tensor, mean, std)
@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
        a (float): the minimum cutoff value. Defaults -2.0.
        b (float): the maximum cutoff value. Defaults 2.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)
@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.):
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
@@ -241,4 +249,4 @@ def lecun_normal_():
        std = math.sqrt(1.0 / fan_in)
        return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
    return initializer
\ No newline at end of file
    return initializer
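
For context, every function touched by this commit follows the same factory pattern: the public function (zeros_, ones_, normal_, kaiming_uniform_, xavier_uniform_, lecun_normal_, ...) returns an initializer closure that takes a tensor plus optional fan_in/fan_out and fills the tensor in place via torch.nn.init. Below is a minimal usage sketch, assuming the module is importable as colossalai.nn.init; the weight tensor and fan values are hypothetical and only illustrate how the closures are meant to be called.

import math

import torch
from colossalai.nn import init

# Hypothetical 2D weight of shape (out_features, in_features); fan values come from its shape.
weight = torch.empty(128, 64)
fan_in, fan_out = weight.shape[1], weight.shape[0]

# Each factory call returns a closure; calling the closure fills `weight` in place.
init.zeros_()(weight)
init.kaiming_uniform_(a=math.sqrt(5))(weight, fan_in=fan_in)
init.xavier_uniform_()(weight, fan_in=fan_in, fan_out=fan_out)
init.lecun_normal_()(weight, fan_in=fan_in)

One detail visible in the last hunk: the constant .87962566103423978 is the standard deviation of a standard normal truncated to [-2, 2], so lecun_normal_ divides the target std sqrt(1.0 / fan_in) by it to compensate for the variance lost to truncation in nn.init.trunc_normal_.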