chenpangpang / transformers
Commit 91f3dfbf (Unverified)
Authored Dec 12, 2021 by Patrick von Platen
Committed by GitHub, Dec 12, 2021

[Adafactor] Fix adafactor (#14713)

* correct changes
* add comment

Parent: 86dd23bb
Showing 1 changed file with 5 additions and 3 deletions

src/transformers/optimization.py @ 91f3dfbf (+5 -3)
@@ -503,9 +503,11 @@ class Adafactor(Optimizer):
 
     @staticmethod
     def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
-        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_()
-        c_factor = exp_avg_sq_col.rsqrt()
-        return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))
+        # copy from fairseq's adafactor implementation:
+        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
 
     def step(self, closure=None):
         """
...
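The commit message does not say why `torch.mm` was replaced, so the following is only a sketch of one observable difference between the two versions, not a statement of the author's intent. It copies both bodies of `_approx_sq_grad` from the diff into standalone functions (the names `approx_sq_grad_mm` and `approx_sq_grad_mul` are hypothetical, chosen for illustration) and compares them on a matrix-shaped and a batched input. The batched shapes assume the row and column statistics are means over the last and second-to-last dimension of a 3-D gradient.

```python
import torch


def approx_sq_grad_mm(exp_avg_sq_row, exp_avg_sq_col):
    # Body of the removed lines: an outer product via torch.mm,
    # which accepts only 2-D operands.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_()
    c_factor = exp_avg_sq_col.rsqrt()
    return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))


def approx_sq_grad_mul(exp_avg_sq_row, exp_avg_sq_col):
    # Body of the added lines: the same outer product expressed as a
    # broadcast elementwise multiply, so leading batch dimensions carry through.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return torch.mul(r_factor, c_factor)


if __name__ == "__main__":
    # Statistics for a 4 x 3 weight matrix: both versions agree.
    row, col = torch.rand(4) + 0.1, torch.rand(3) + 0.1
    print(torch.allclose(approx_sq_grad_mm(row, col), approx_sq_grad_mul(row, col)))

    # Statistics for an assumed 2 x 4 x 3 tensor: only the broadcast version runs.
    row3, col3 = torch.rand(2, 4) + 0.1, torch.rand(2, 3) + 0.1
    print(approx_sq_grad_mul(row3, col3).shape)
    try:
        approx_sq_grad_mm(row3, col3)
    except RuntimeError as err:
        print("torch.mm version failed:", type(err).__name__)
```

For 1-D row and column statistics the two functions return the same matrix. With an extra leading dimension, the `torch.mm` version raises because its operands are no longer matrices, while the `unsqueeze(-1)` / `unsqueeze(-2)` pairing keeps the row and column axes aligned under broadcasting.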