Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
622f8632
Unverified
Commit
622f8632
authored
Dec 22, 2022
by
アマデウス
Committed by
GitHub
Dec 22, 2022
Browse files
[hotfix] Jit type hint #2161 (#2164)
parent
27327a4c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
colossalai/nn/layer/parallel_3d/_operation.py
colossalai/nn/layer/parallel_3d/_operation.py
+2
-2
No files found.
colossalai/nn/layer/parallel_3d/_operation.py
100644 → 100755
View file @
622f8632
...
@@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
...
@@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
norm_forward
(
x
,
mean
,
sqr_mean
,
weight
,
bias
,
eps
):
def
norm_forward
(
x
:
Tensor
,
mean
:
Tensor
,
sqr_mean
:
Tensor
,
weight
:
Tensor
,
bias
:
Tensor
,
eps
:
float
):
mu
=
x
-
mean
mu
=
x
-
mean
var
=
sqr_mean
-
mean
**
2
var
=
sqr_mean
-
mean
**
2
sigma
=
torch
.
sqrt
(
var
+
eps
)
sigma
=
torch
.
sqrt
(
var
+
eps
)
...
@@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
...
@@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
norm_backward
(
grad
,
mu
,
sigma
,
weight
):
def
norm_backward
(
grad
:
Tensor
,
mu
:
Tensor
,
sigma
:
Tensor
,
weight
:
Tensor
):
# dbias, dweight = grad, grad * mu / sigma
# dbias, dweight = grad, grad * mu / sigma
dz
=
grad
*
weight
dz
=
grad
*
weight
dmu
=
dz
/
sigma
dmu
=
dz
/
sigma
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment