Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cda9c82a
Unverified
Commit
cda9c82a
authored
May 30, 2024
by
zspo
Committed by
GitHub
May 30, 2024
Browse files
fix get_scheduler when name is warmup_stable_decay (#31128)
fix get_scheduler args
parent
5e5c4d62
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
0 deletions
+25
-0
src/transformers/optimization.py
src/transformers/optimization.py
+3
-0
tests/optimization/test_optimization.py
tests/optimization/test_optimization.py
+22
-0
No files found.
src/transformers/optimization.py
View file @
cda9c82a
...
@@ -540,6 +540,9 @@ def get_scheduler(
...
@@ -540,6 +540,9 @@ def get_scheduler(
if
name
==
SchedulerType
.
INVERSE_SQRT
:
if
name
==
SchedulerType
.
INVERSE_SQRT
:
return
schedule_func
(
optimizer
,
num_warmup_steps
=
num_warmup_steps
)
return
schedule_func
(
optimizer
,
num_warmup_steps
=
num_warmup_steps
)
if
name
==
SchedulerType
.
WARMUP_STABLE_DECAY
:
return
schedule_func
(
optimizer
,
num_warmup_steps
=
num_warmup_steps
,
**
scheduler_specific_kwargs
)
# All other schedulers require `num_training_steps`
# All other schedulers require `num_training_steps`
if
num_training_steps
is
None
:
if
num_training_steps
is
None
:
raise
ValueError
(
f
"
{
name
}
requires `num_training_steps`, please provide that argument."
)
raise
ValueError
(
f
"
{
name
}
requires `num_training_steps`, please provide that argument."
)
...
...
tests/optimization/test_optimization.py
View file @
cda9c82a
...
@@ -36,6 +36,7 @@ if is_torch_available():
...
@@ -36,6 +36,7 @@ if is_torch_available():
get_inverse_sqrt_schedule
,
get_inverse_sqrt_schedule
,
get_linear_schedule_with_warmup
,
get_linear_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
get_scheduler
,
get_wsd_schedule
,
get_wsd_schedule
,
)
)
...
@@ -176,6 +177,27 @@ class ScheduleInitTest(unittest.TestCase):
...
@@ -176,6 +177,27 @@ class ScheduleInitTest(unittest.TestCase):
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
(
lrs_1
,
lrs_2
,
msg
=
f
"failed for
{
scheduler_func
}
in save and reload"
)
self
.
assertListEqual
(
lrs_1
,
lrs_2
,
msg
=
f
"failed for
{
scheduler_func
}
in save and reload"
)
def test_get_scheduler(self):
    """Smoke-test that `get_scheduler` builds a scheduler from a scheduler name.

    Three configurations are exercised:

    * `warmup_stable_decay` without `num_training_steps`,
    * `warmup_stable_decay` with `num_training_steps`,
    * `cosine` with `num_training_steps`.

    Each configuration runs in its own `subTest`, so a failure in one is
    reported with its parameters and does not hide the remaining ones.
    """
    test_params = [
        # warmup_stable_decay is configured through `scheduler_specific_kwargs`;
        # this case deliberately omits `num_training_steps`.
        {
            "name": "warmup_stable_decay",
            "optimizer": self.optimizer,
            "num_warmup_steps": 2,
            "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
        },
        # Same scheduler, but with `num_training_steps` also supplied.
        {
            "name": "warmup_stable_decay",
            "optimizer": self.optimizer,
            "num_warmup_steps": 2,
            "num_training_steps": 10,
            "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
        },
        # A second scheduler type, built with `num_training_steps`.
        {
            "name": "cosine",
            "optimizer": self.optimizer,
            "num_warmup_steps": 2,
            "num_training_steps": 10,
        },
    ]

    for param in test_params:
        # Label each case so the report identifies which configuration broke.
        with self.subTest(name=param["name"], num_training_steps=param.get("num_training_steps")):
            scheduler = get_scheduler(**param)
            # Check that an object came back rather than relying on its truthiness.
            self.assertIsNotNone(scheduler, msg=f"failed for {param['name']} in get_scheduler")
class
LambdaScheduleWrapper
:
class
LambdaScheduleWrapper
:
"""See https://github.com/huggingface/transformers/issues/21689"""
"""See https://github.com/huggingface/transformers/issues/21689"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment