OpenDAS / deepspeed · Commits · 96c4daab

Unverified commit 96c4daab, authored Jun 12, 2020 by Chunyang Wen, committed by GitHub Jun 11, 2020

minor refactor loss scaler (#261)

Parent: f5025506
Pipeline #203 failed in 0 seconds
Showing 1 changed file with 28 additions and 31 deletions.

deepspeed/pt/loss_scaler.py (+28, -31)
@@ -31,7 +31,29 @@ def to_python_float(t):
         return t[0]
 
 
-class LossScaler:
+class LossScalerBase:
+    """LossScalarBase
+    Base class for a loss scaler
+    """
+    def __init__(self, cur_scale):
+        self.cur_scale = cur_scale
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    def scale_gradient(self, module, grad_in, grad_out):
+        return tuple(self.loss_scale * g for g in grad_in)
+
+    def update_scale(self, overflow):
+        pass
+
+    def backward(self, loss, retain_graph=False):
+        scaled_loss = loss * self.loss_scale
+        scaled_loss.backward(retain_graph=retain_graph)
+
+
+class LossScaler(LossScalerBase):
     """
     Class that manages a static loss scale. This class is intended to interact with
     :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
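Note: the new LossScalerBase collects the members that were previously duplicated in both scalers (the loss_scale property, scale_gradient, update_scale and backward), so the concrete classes only override what actually differs. As a quick standalone sanity check of the shared backward() path (my own snippet, not part of the commit; it assumes this commit's deepspeed.pt.loss_scaler module layout), gradients come back multiplied by the configured scale:

import torch
from deepspeed.pt.loss_scaler import LossScaler

w = torch.nn.Parameter(torch.tensor([2.0]))
loss = (w * 3.0).sum()

scaler = LossScaler(scale=128)
scaler.backward(loss)    # LossScalerBase.backward: (loss * 128).backward()
print(w.grad)            # tensor([384.]) == 128 * d(3*w)/dw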
@@ -43,7 +65,7 @@ class LossScaler:
         scale (float, optional, default=1.0): The loss scale.
     """
     def __init__(self, scale=1):
-        self.cur_scale = scale
+        super(LossScaler, self).__init__(scale)
 
     # `params` is a list / generator of torch.Variable
     def has_overflow(self, params):
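With cur_scale now set by the base class, LossScaler.__init__ reduces to a delegation call. The super(LossScaler, self).__init__(scale) spelling is the Python 2-compatible form; inside the class body it behaves the same as Python 3's zero-argument super(), as this small sketch with hypothetical stand-in classes shows:

class _Base:
    def __init__(self, cur_scale):
        self.cur_scale = cur_scale

class _StaticScaler(_Base):
    def __init__(self, scale=1):
        super().__init__(scale)    # same effect as super(_StaticScaler, self).__init__(scale)

print(_StaticScaler(128).cur_scale)    # 128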
@@ -53,22 +75,8 @@ class LossScaler:
     def _has_inf_or_nan(x):
         return False
 
-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
 
-class DynamicLossScaler:
+class DynamicLossScaler(LossScalerBase):
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
     indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
@@ -100,7 +108,7 @@ class DynamicLossScaler:
                  min_scale=1,
                  delayed_shift=1,
                  consecutive_hysteresis=False):
-        self.cur_scale = init_scale
+        super(DynamicLossScaler, self).__init__(init_scale)
         self.cur_iter = 0
         self.last_overflow_iter = -1
         self.scale_factor = scale_factor
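Apart from delegating cur_scale to the base class, the dynamic scaler keeps its behaviour: back the scale off on overflow, grow it again after a run of clean iterations (the self.cur_scale *= self.scale_factor line appears unchanged as context further down). A hedged usage sketch, assuming the apex-style update_scale semantics this file inherits and that scale_window is the growth interval:

from deepspeed.pt.loss_scaler import DynamicLossScaler

scaler = DynamicLossScaler(init_scale=2**16, scale_factor=2.0, scale_window=2)
print(scaler.loss_scale)              # 65536

scaler.update_scale(overflow=True)    # NaN/Inf seen this step: back off
print(scaler.loss_scale)              # expected 32768.0

scaler.update_scale(overflow=False)   # two clean steps in a row ...
scaler.update_scale(overflow=False)   # ... and the scale grows again
print(scaler.loss_scale)              # expected 65536.0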
@@ -113,7 +121,7 @@ class DynamicLossScaler:
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
         for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                 return True
         return False
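Calling self._has_inf_or_nan(...) instead of DynamicLossScaler._has_inf_or_nan(...) routes the check through the instance rather than hard-coding the class name; presumably the point is that a subclass can substitute its own check. A self-contained sketch of that dispatch pattern with hypothetical classes (not code from this file):

class GradChecker:
    def _has_inf_or_nan(self, x):
        return x != x    # NaN is the only float that is not equal to itself

    def has_overflow_serial(self, values):
        return any(self._has_inf_or_nan(v) for v in values)

class StrictGradChecker(GradChecker):
    def _has_inf_or_nan(self, x):
        return x != x or x in (float('inf'), -float('inf'))

print(GradChecker().has_overflow_serial([1.0, float('inf')]))        # False: base check ignores inf
print(StrictGradChecker().has_overflow_serial([1.0, float('inf')]))  # True: the override is used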
@@ -135,7 +143,7 @@ class DynamicLossScaler:
                 raise
             return True
         else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                 return True
             return False
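The rewritten condition is behaviourally identical to the old chain of equality tests: the membership test covers both infinities, and cpu_sum != cpu_sum is the usual NaN test, since NaN is the only float that compares unequal to itself. A standalone check (my own snippet):

def has_inf_or_nan_scalar(cpu_sum):
    # same predicate as the refactored line
    return cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum

print(has_inf_or_nan_scalar(float('nan')))    # True  (fails the self-equality test)
print(has_inf_or_nan_scalar(-float('inf')))   # True  (caught by the membership test)
print(has_inf_or_nan_scalar(3.14))            # False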
@@ -157,17 +165,6 @@ class DynamicLossScaler:
                 self.cur_scale *= self.scale_factor
         self.cur_iter += 1
 
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
 
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
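After the refactor both scalers expose the same surface (loss_scale, backward, update_scale, scale_gradient) through LossScalerBase, which is what :class:`FP16_Optimizer` drives. A rough end-to-end sketch of that flow (my own toy example; the model, optimizer and the torch.isfinite overflow check are stand-ins, and in DeepSpeed the unscale-and-step handling normally lives inside FP16_Optimizer rather than user code):

import torch
from deepspeed.pt.loss_scaler import DynamicLossScaler

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = DynamicLossScaler(init_scale=2**16)

for step in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    scaler.backward(loss)                        # shared LossScalerBase.backward: scale, then backprop
    # stand-in overflow check; the scaler's own has_overflow_serial plays this role in the file
    overflow = any(p.grad is not None and not torch.isfinite(p.grad).all()
                   for p in model.parameters())
    if not overflow:
        for p in model.parameters():
            p.grad.data.mul_(1.0 / scaler.loss_scale)    # unscale with the scale used for this backward
        optimizer.step()
    scaler.update_scale(overflow)                # adjust cur_scale for the next step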