chenpangpang / transformers

Commit fe8d1302 (unverified), authored Dec 08, 2023 by Charbel Abi Daher, committed via GitHub on Dec 08, 2023
Parent commit: 56be5e80

Added passing parameters to "reduce_lr_on_plateau" scheduler (#27860)
Showing 1 changed file with 12 additions and 6 deletions

src/transformers/optimization.py (+12, -6)
@@ -53,19 +53,22 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
     return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
 
 
-def get_reduce_on_plateau_schedule(optimizer: Optimizer):
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
     """
     Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
 
     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
+        kwargs (`dict`, *optional*):
+            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+            for possible parameters.
 
     Return:
         `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
     """
-    return ReduceLROnPlateau(optimizer)
+    return ReduceLROnPlateau(optimizer, **kwargs)
 
 
 def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
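
For context, a minimal sketch of how the extended helper could now be called. The optimizer setup and the hyperparameter values are illustrative and not part of this commit; the keyword names (factor, patience, min_lr) are standard torch.optim.lr_scheduler.ReduceLROnPlateau parameters, which is exactly what the new **kwargs forwards.

# Illustrative sketch, not part of the commit: extra keyword arguments are
# now forwarded unchanged to torch's ReduceLROnPlateau.
import torch
from transformers.optimization import get_reduce_on_plateau_schedule

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Before this change, the helper accepted no scheduler options at all.
scheduler = get_reduce_on_plateau_schedule(optimizer, factor=0.5, patience=2, min_lr=1e-6)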
...

@@ -359,9 +362,15 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
+    if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
 
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
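
A hedged usage sketch for the get_scheduler path changed above. The scheduler_specific_kwargs argument and the REDUCE_ON_PLATEAU branch appear in the diff itself, and the string name matches the commit title; the optimizer setup and the concrete values are illustrative.

# Illustrative sketch, not part of the commit: the REDUCE_ON_PLATEAU branch of
# get_scheduler now forwards scheduler_specific_kwargs to the scheduler factory.
import torch
from transformers import get_scheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

scheduler = get_scheduler(
    "reduce_lr_on_plateau",  # SchedulerType.REDUCE_ON_PLATEAU
    optimizer=optimizer,
    scheduler_specific_kwargs={"factor": 0.5, "patience": 2},  # passed through to ReduceLROnPlateau
)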
...

@@ -376,9 +385,6 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    if scheduler_specific_kwargs is None:
-        scheduler_specific_kwargs = {}
-
     return schedule_func(
         optimizer,
         num_warmup_steps=num_warmup_steps,
...
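
Finally, a small illustrative sketch of how a schedule built this way is typically stepped. Unlike the warmup-based schedules, torch's ReduceLROnPlateau reacts to a monitored metric, so step() is called with, for example, a validation loss; the values below are made up and not taken from the repository.

# Illustrative sketch, not part of the commit: ReduceLROnPlateau is driven by a
# monitored metric rather than by the optimizer step count.
import torch
from transformers.optimization import get_reduce_on_plateau_schedule

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = get_reduce_on_plateau_schedule(optimizer, patience=1, factor=0.1)

fake_val_losses = [0.9, 0.8, 0.8, 0.8]  # improvement stalls after the second value
for loss in fake_val_losses:
    scheduler.step(loss)  # step() takes the monitored metric
    print(optimizer.param_groups[0]["lr"])  # drops to 0.01 once patience is exhausted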