chenpangpang / transformers, commit 8aa8513f (unverified)
Authored May 19, 2023 by Jiewen Tan; committed by GitHub on May 19, 2023
Parent: 3cf01b20

Remove .data usages in optimizations.py (#23417)

Patched the optimizers.

Showing 1 changed file, src/transformers/optimization.py, with 12 additions and 10 deletions (+12, -10).
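The change follows the pattern used by PyTorch's built-in optimizers: decorate step() with @torch.no_grad() and update the parameter tensors in place directly, rather than detouring through the long-deprecated .data attribute. Below is a minimal sketch of the before/after pattern, using a toy SGD-style optimizer rather than the actual transformers classes:

# Toy example (not the transformers code) of the pattern this commit applies:
# replace `.data` access with a `@torch.no_grad()`-decorated `step()` that
# updates parameters in place.
import torch
from torch.optim import Optimizer


class ToySGD(Optimizer):
    def __init__(self, params, lr=0.01):
        super().__init__(params, {"lr": lr})

    @torch.no_grad()  # disables autograd tracking, so `.data` is no longer needed
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad                      # previously: grad = p.grad.data
                p.add_(grad, alpha=-group["lr"])   # previously: p.data.add_(...)
        return loss


# Quick check that the in-place update works without touching `.data`.
w = torch.nn.Parameter(torch.ones(3))
opt = ToySGD([w], lr=0.1)
w.sum().backward()
opt.step()
print(w)  # tensor([0.9000, 0.9000, 0.9000], requires_grad=True)

Under torch.no_grad(), in-place updates on leaf parameters are allowed and are not recorded by autograd, which is exactly the escape hatch that .data used to provide.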
@@ -422,6 +422,7 @@ class AdamW(Optimizer):
         defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
         super().__init__(params, defaults)
 
+    @torch.no_grad()
     def step(self, closure: Callable = None):
         """
         Performs a single optimization step.
@@ -437,7 +438,7 @@ class AdamW(Optimizer):
             for p in group["params"]:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
@@ -447,9 +448,9 @@ class AdamW(Optimizer):
                 if len(state) == 0:
                     state["step"] = 0
                     # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(p.data)
+                    state["exp_avg"] = torch.zeros_like(p)
                     # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(p.data)
+                    state["exp_avg_sq"] = torch.zeros_like(p)
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]
@@ -468,7 +469,7 @@ class AdamW(Optimizer):
                     bias_correction2 = 1.0 - beta2 ** state["step"]
                     step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
 
-                p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
                 # Just adding the square of the weights to the loss function is *not*
                 # the correct way of using L2 regularization/weight decay with Adam,
@@ -479,7 +480,7 @@ class AdamW(Optimizer):
                 # of the weights to the loss with plain (non-momentum) SGD.
                 # Add weight decay at the end (fixed version)
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
 
         return loss
@@ -630,6 +631,7 @@ class Adafactor(Optimizer):
         c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         return torch.mul(r_factor, c_factor)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """
         Performs a single optimization step
@@ -646,7 +648,7 @@ class Adafactor(Optimizer):
             for p in group["params"]:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.dtype in {torch.float16, torch.bfloat16}:
                     grad = grad.float()
                 if grad.is_sparse:
@@ -679,8 +681,8 @@ class Adafactor(Optimizer):
                 else:
                     state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
 
-                p_data_fp32 = p.data
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
+                p_data_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     p_data_fp32 = p_data_fp32.float()
 
                 state["step"] += 1
@@ -718,8 +720,8 @@ class Adafactor(Optimizer):
                 p_data_fp32.add_(-update)
 
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p.data.copy_(p_data_fp32)
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_data_fp32)
 
         return loss
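Call sites are unaffected by the patch; only the internals of step() change. An illustrative usage sketch follows (the model and hyperparameters are made up for the example):

# Illustrative only: drives the patched optimizers the same way as before the commit.
import torch
from transformers.optimization import AdamW, Adafactor

model = torch.nn.Linear(4, 2)

optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()       # now runs under @torch.no_grad(); no .data involved
optimizer.zero_grad()

# Adafactor, also patched above, is driven the same way.
optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False, relative_step=False)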