OpenDAS / deepspeed · Commit b7f5cb78
"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0be8d8a7f6a86384c1c4a42a645669a37f9c01f1"
Early Return Pattern "if return else return" -> "if return return" (#197)

Unverified commit b7f5cb78, authored Apr 21, 2020 by marload, committed by GitHub Apr 20, 2020
Parent: 675d73e0
Showing 2 changed files with 7 additions and 12 deletions (+7 / -12):

    deepspeed/pt/deepspeed_lr_schedules.py    +6 / -10
    deepspeed/pt/loss_scaler.py               +1 / -2
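Every one of the seven added lines below is the same mechanical refactor: when each branch of an if/elif/else chain ends in a return, the elif/else keywords are redundant, because later code only runs if the earlier branch did not return. A minimal before/after sketch with a hypothetical function (not taken from the diff):

    # Before: every branch returns, so the else adds no information.
    def sign(x):
        if x >= 0:
            return 1
        else:
            return -1

    # After: the early return makes the fall-through case the default path.
    def sign(x):
        if x >= 0:
            return 1
        return -1

Behavior is unchanged; only the keyword and one level of nesting go away.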
deepspeed/pt/deepspeed_lr_schedules.py (view file @ b7f5cb78)
...
@@ -271,11 +271,10 @@ def get_lr_from_config(config):
     if lr_schedule == LR_RANGE_TEST:
         return lr_params[LR_RANGE_TEST_MIN_LR], ''
-    elif lr_schedule == ONE_CYCLE:
+    if lr_schedule == ONE_CYCLE:
         return lr_params[CYCLE_MAX_LR], ''
-    else:
-        # Warmup LR
-        return lr_params[WARMUP_MAX_LR], ''
+    # Warmup LR
+    return lr_params[WARMUP_MAX_LR], ''
     """
...
@@ -624,8 +623,7 @@ class OneCycle(object):
         """
         if self.last_batch_iteration <= self.total_size:
             return self._get_cycle_lr()
-        else:
-            return self._get_decay_lr(self.last_batch_iteration - self.total_size)
+        return self._get_decay_lr(self.last_batch_iteration - self.total_size)

     def step(self, batch_iteration=None):
         if batch_iteration is None:
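For context, the branch flattened here selects between the cyclic phase and the post-cycle decay phase of OneCycle. A stripped-down sketch of that control flow with stub methods and hypothetical numbers (the real _get_cycle_lr / _get_decay_lr contain the actual LR math):

    class OneCycleSketch:
        def __init__(self, total_size):
            self.total_size = total_size
            self.last_batch_iteration = 0

        def _get_cycle_lr(self):
            return 'cycle phase'  # placeholder for the real cyclic LR computation

        def _get_decay_lr(self, decay_steps):
            return f'decay phase, step {decay_steps}'  # placeholder for the real decay

        def get_lr(self):
            # Same shape as the refactored method: early return while inside the
            # cycle, fall through to decay once the cycle is exhausted.
            if self.last_batch_iteration <= self.total_size:
                return self._get_cycle_lr()
            return self._get_decay_lr(self.last_batch_iteration - self.total_size)

    sched = OneCycleSketch(total_size=100)
    sched.last_batch_iteration = 150
    print(sched.get_lr())  # decay phase, step 50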
...
@@ -701,8 +699,7 @@ class WarmupLR(object):
     def _get_gamma(self):
         if self.last_batch_iteration < self.warmup_num_steps:
             return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
-        else:
-            return 1.0
+        return 1.0

     def _format_param(self, optimizer, param_value, param_name):
         if isinstance(param_value, list) or isinstance(param_value, tuple):
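The factor computed by _get_gamma ramps the LR logarithmically during warmup and clamps at 1.0 afterwards. Assuming inverse_log_warm_up is 1.0 / math.log(warmup_num_steps), as it appears to be set in the class constructor (not visible in this hunk), a quick numeric check of both regimes:

    import math

    warmup_num_steps = 1000
    inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)  # assumed definition from __init__

    def get_gamma(last_batch_iteration):
        # Same logic as the refactored method, as a free function for illustration.
        if last_batch_iteration < warmup_num_steps:
            return inverse_log_warm_up * math.log(last_batch_iteration + 1)
        return 1.0

    print(get_gamma(0))     # 0.0   -- log(1) == 0, so training starts at the minimum scale
    print(get_gamma(99))    # ~0.667 -- log(100) / log(1000)
    print(get_gamma(999))   # ~1.0  -- log(1000) / log(1000)
    print(get_gamma(5000))  # 1.0 exactly, warmup finished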
...
@@ -712,5 +709,4 @@ class WarmupLR(object):
                     param_name,
                     FileNotFoundError(param_value)))
             return list(param_value)
-        else:
-            return [param_value] * len(optimizer.param_groups)
+        return [param_value] * len(optimizer.param_groups)
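_format_param either validates a user-supplied list/tuple against the number of optimizer param groups or broadcasts a scalar to one value per group; the removed else only guarded the scalar branch. A hedged sketch of the same broadcast logic on a plain group count, without a real torch optimizer:

    def format_param_sketch(num_param_groups, param_value):
        # Mirrors the refactored control flow: the list/tuple branch returns early,
        # so the scalar broadcast no longer needs an else.
        if isinstance(param_value, (list, tuple)):
            if len(param_value) != num_param_groups:
                raise ValueError(
                    f'expected {num_param_groups} values, got {len(param_value)}')
            return list(param_value)
        return [param_value] * num_param_groups

    print(format_param_sketch(3, 0.01))         # [0.01, 0.01, 0.01]
    print(format_param_sketch(2, (0.1, 0.01)))  # [0.1, 0.01]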
deepspeed/pt/loss_scaler.py (view file @ b7f5cb78)
...
@@ -23,8 +23,7 @@ import torch
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
-    else:
-        return t[0]
+    return t[0]


 class LossScaler:
...
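to_python_float is a small compatibility shim: tensors that expose .item() are converted directly, anything else falls back to indexing the first element. A short usage sketch (assumes torch is installed; the fallback path is only exercised on objects without .item):

    import torch

    def to_python_float(t):
        # After the refactor: early return for objects with .item(), indexing otherwise.
        if hasattr(t, 'item'):
            return t.item()
        return t[0]

    print(to_python_float(torch.tensor(3.5)))  # 3.5, via .item()
    print(to_python_float([2.5]))              # 2.5, via the t[0] fallback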