Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
73e9eb13
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "1949d3a88925de1e0c2a7a9a713e2b1dde0449c7"
Commit
73e9eb13
authored
Sep 08, 2022
by
binmakeswell
Committed by
Frank Lee
Sep 08, 2022
Browse files
[NFC] polish colossalai/nn/lr_scheduler/cosine.py code style
parent
318fbf11
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
14 deletions
+17
-14
colossalai/nn/lr_scheduler/cosine.py
colossalai/nn/lr_scheduler/cosine.py
+17
-14
No files found.
colossalai/nn/lr_scheduler/cosine.py
View file @
73e9eb13
...
...
@@ -62,8 +62,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1):
    """Build the warmup-then-cosine schedule.

    :param optimizer: wrapped optimizer whose learning rate is scheduled
    :param total_steps: total number of scheduler steps, warmup included
    :param warmup_steps: number of leading warmup steps, defaults to 0
    :param eta_min: floor learning rate handed to the cosine phase, defaults to 0.
    :param last_epoch: index of the last epoch; -1 starts the schedule from the beginning
    """
    # The cosine phase only spans the steps remaining after warmup.
    annealing = _CosineAnnealingLR(optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
    super().__init__(optimizer, warmup_steps, annealing)
...
...
@@ -81,12 +83,10 @@ class FlatAnnealingLR(DelayerScheduler):
def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs):
    """Build the flat-then-cosine schedule.

    :param optimizer: wrapped optimizer whose learning rate is scheduled
    :param total_steps: total number of scheduler steps
    :param pct_start: fraction of total_steps spent in the flat phase, defaults to 0.72
    :param last_epoch: index of the last epoch; -1 starts the schedule from the beginning
    :raises ValueError: if pct_start lies outside [0.0, 1.0]
    """
    # Guard the fraction before deriving step counts from it.
    if not 0.0 <= pct_start <= 1.0:
        raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
    hold_steps = int(total_steps * pct_start)
    decay_steps = total_steps - hold_steps
    annealing = _CosineAnnealingLR(optimizer, decay_steps)
    super().__init__(optimizer, hold_steps, annealing, last_epoch=last_epoch)
...
...
@@ -105,14 +105,17 @@ class FlatAnnealingWarmupLR(WarmupDelayerScheduler):
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
"""
def __init__(self,
             optimizer,
             total_steps: int,
             warmup_steps: int = 0,
             pct_start: float = 0.72,
             eta_min: float = 0.,
             last_epoch: int = -1,
             **kwargs):
    """Build the warmup-then-flat-then-cosine schedule.

    :param optimizer: wrapped optimizer whose learning rate is scheduled
    :param total_steps: total number of scheduler steps, warmup included
    :param warmup_steps: number of leading warmup steps, defaults to 0
    :param pct_start: fraction of the post-warmup steps spent in the flat phase,
        defaults to 0.72
    :param eta_min: floor learning rate handed to the cosine phase, defaults to 0.
        (annotated float, consistent with CosineAnnealingWarmupLR; was mistyped int)
    :param last_epoch: index of the last epoch; -1 starts the schedule from the beginning
    :raises ValueError: if pct_start lies outside [0.0, 1.0]
    """
    if not (0.0 <= pct_start <= 1.0):
        raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
    # Flat and annealing phases split whatever remains after warmup.
    flat_steps = int((total_steps - warmup_steps) * pct_start)
    anneal_steps = total_steps - warmup_steps - flat_steps
    base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min)
    super().__init__(optimizer, warmup_steps, flat_steps, base_scheduler, last_epoch=last_epoch)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment