OpenDAS / ColossalAI · Commits · 2dd4d556
"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "4d582893a79b72a878e8fac52b3282799e404636"
Unverified commit 2dd4d556, authored Jul 13, 2022 by Ofey Chan; committed by GitHub on Jul 13, 2022.
[NFC] polish colossalai/nn/init.py code style (#1292)
Parent: 556b9b7e

Showing 1 changed file with 9 additions and 1 deletion:

colossalai/nn/init.py (+9, -1)
@@ -7,6 +7,7 @@ import torch.nn as nn

def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros"""

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.zeros_(tensor)
@@ -15,6 +16,7 @@ def zeros_():

def ones_():
    """Return the initializer filling the input Tensor with the scalar ones"""

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.ones_(tensor)
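For context on the pattern in the two hunks above: every public function in this file is a factory that returns an initializer closure with a uniform (tensor, fan_in, fan_out) signature, so fan-independent schemes like zeros_ and ones_ accept (and ignore) the fan arguments that fan-dependent schemes need. A minimal usage sketch, with an illustrative tensor shape not taken from the diff:

    import torch
    from colossalai.nn.init import zeros_

    weight = torch.empty(4, 3)            # illustrative shape
    init_fn = zeros_()                    # build the initializer once
    init_fn(weight, fan_in=3, fan_out=4)  # fans are ignored by zeros_, but the
                                          # uniform signature lets callers pass
                                          # them unconditionally
    print(weight.sum())                   # tensor(0.)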
@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.):

        mean (float): the mean of the normal distribution. Defaults 0.0.
        std (float): the standard deviation of the normal distribution. Defaults 1.0.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.normal_(tensor, mean, std)
@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float =

        a (float): the minimum cutoff value. Defaults -2.0.
        b (float): the maximum cutoff value. Defaults 2.0.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)
@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):

        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):

        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
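Both kaiming hunks break off at an `if 0 in tensor.shape:` guard. Since the code is marked "adapted from torch.nn.init", the hidden branch presumably mirrors torch's handling of zero-element tensors, which warns and returns the tensor unchanged. A hedged sketch of that guard (the continuation is elided, as in the diff):

    import warnings
    import torch

    def initializer(tensor: torch.Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
            # torch.nn.init warns and bails out for empty tensors, since fan
            # computation is meaningless when a dimension is zero
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor
        ...  # fan-based Kaiming initialization continues in the real file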
@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1

        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.):

        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
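The `assert fan_in is not None` line visible in both xavier hunks reflects a design choice: these initializers expect the caller to supply fan values rather than inferring them from the tensor (torch.nn.init computes them internally instead). A sketch of how a caller might derive fans for a 2-D weight; the helper name is hypothetical, not from the diff:

    import torch

    def compute_fans(weight: torch.Tensor):  # hypothetical helper
        # For a 2-D weight of shape (out_features, in_features):
        fan_in = weight.shape[1]    # inputs feeding each output unit
        fan_out = weight.shape[0]   # outputs fed by each input unit
        return fan_in, fan_out

    w = torch.empty(4, 3)
    print(compute_fans(w))  # (3, 4)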
@@ -241,4 +249,4 @@ def lecun_normal_():

        std = math.sqrt(1.0 / fan_in)
        return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)

-    return initializer
\ No newline at end of file
+    return initializer
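A note on the magic constant in the last hunk: 0.87962566103423978 is the standard deviation of a unit normal truncated to [-2, 2], so dividing the target std by it compensates for the spread lost to clipping. A quick empirical check (a sketch, not part of the commit):

    import torch
    import torch.nn as nn

    TRUNC_STD = 0.87962566103423978  # std of N(0, 1) clipped to [-2, 2]

    t = torch.empty(1_000_000)
    nn.init.trunc_normal_(t, mean=0.0, std=1.0, a=-2.0, b=2.0)
    print(t.std())                # ~0.8796: truncation shrinks the spread
    print((t / TRUNC_STD).std())  # ~1.0 once the correction is applied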