deepspeed · commit 96c4daab (unverified)
Authored Jun 12, 2020 by Chunyang Wen; committed by GitHub on Jun 11, 2020

minor refactor loss scaler (#261)

Parent: f5025506
Pipeline #203 failed with stages in 0 seconds
Showing 1 changed file with 28 additions and 31 deletions.
deepspeed/pt/loss_scaler.py (+28, -31) · view file @ 96c4daab
@@ -31,7 +31,29 @@ def to_python_float(t):
     return t[0]


-class LossScaler:
+class LossScalerBase:
+    """LossScalarBase
+    Base class for a loss scaler
+    """
+    def __init__(self, cur_scale):
+        self.cur_scale = cur_scale
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    def scale_gradient(self, module, grad_in, grad_out):
+        return tuple(self.loss_scale * g for g in grad_in)
+
+    def update_scale(self, overflow):
+        pass
+
+    def backward(self, loss, retain_graph=False):
+        scaled_loss = loss * self.loss_scale
+        scaled_loss.backward(retain_graph=retain_graph)
+
+
+class LossScaler(LossScalerBase):
     """
     Class that manages a static loss scale. This class is intended to interact with
     :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
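The hunk above hoists the four methods the two scalers previously duplicated (loss_scale, scale_gradient, update_scale, backward) into the new LossScalerBase. A minimal sketch of the API callers rely on, assuming PyTorch is installed; the tensor usage below is illustrative and not part of the commit:

    import torch

    # LossScalerBase as introduced above, trimmed to the pieces exercised here.
    class LossScalerBase:
        def __init__(self, cur_scale):
            self.cur_scale = cur_scale

        @property
        def loss_scale(self):
            return self.cur_scale

        def backward(self, loss, retain_graph=False):
            # Scale the loss before backprop so small fp16 gradients stay representable.
            scaled_loss = loss * self.loss_scale
            scaled_loss.backward(retain_graph=retain_graph)

    scaler = LossScalerBase(128.0)
    w = torch.ones(3, requires_grad=True)
    loss = (w * 2.0).sum()
    scaler.backward(loss)   # backprops loss * 128.0
    print(w.grad)           # tensor([256., 256., 256.]): gradients carry the scale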
@@ -43,7 +65,7 @@ class LossScaler:
         scale (float, optional, default=1.0): The loss scale.
     """
     def __init__(self, scale=1):
-        self.cur_scale = scale
+        super(LossScaler, self).__init__(scale)

     # `params` is a list / generator of torch.Variable
     def has_overflow(self, params):
@@ -53,22 +75,8 @@ class LossScaler:
     def _has_inf_or_nan(x):
         return False

-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-

-class DynamicLossScaler:
+class DynamicLossScaler(LossScalerBase):
    """
    Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
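The method bodies deleted above are not lost: with both classes inheriting from LossScalerBase, they resolve to the single base-class definitions. A quick sanity sketch, not part of the commit:

    # Both classes resolve the shared methods to the base-class functions.
    assert LossScaler.backward is LossScalerBase.backward
    assert DynamicLossScaler.scale_gradient is LossScalerBase.scale_gradient
    # DynamicLossScaler still overrides update_scale with its dynamic logic.
    assert DynamicLossScaler.update_scale is not LossScalerBase.update_scale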
@@ -100,7 +108,7 @@ class DynamicLossScaler:
                  min_scale=1,
                  delayed_shift=1,
                  consecutive_hysteresis=False):
-        self.cur_scale = init_scale
+        super(DynamicLossScaler, self).__init__(init_scale)
         self.cur_iter = 0
         self.last_overflow_iter = -1
         self.scale_factor = scale_factor
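With this hunk and the earlier one in LossScaler.__init__, both constructors now funnel through LossScalerBase.__init__, the only place cur_scale is assigned. A hedged usage sketch; the init_scale default shown is an assumption about the class signature, not taken from this diff:

    static = LossScaler(scale=8)
    dynamic = DynamicLossScaler(init_scale=2**32)   # init_scale value assumed
    assert static.loss_scale == 8
    assert dynamic.loss_scale == 2**32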
@@ -113,7 +121,7 @@ class DynamicLossScaler:
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
         for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                 return True

         return False
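Replacing the hard-coded DynamicLossScaler._has_inf_or_nan with self._has_inf_or_nan is behavior-preserving here, but routing the call through normal attribute lookup lets a subclass supply its own check. An illustrative sketch of why self-dispatch matters, simplified to raw floats rather than p.grad.data (class names here are hypothetical, not from the commit):

    class _Base:
        def has_overflow_serial(self, params):
            return any(self._has_inf_or_nan(x) for x in params)

    class _Static(_Base):
        def _has_inf_or_nan(self, x):
            return False            # static scaling never reports overflow

    class _Dynamic(_Base):
        def _has_inf_or_nan(self, x):
            return x != x or x in (float('inf'), -float('inf'))

    print(_Static().has_overflow_serial([float('nan')]))   # False
    print(_Dynamic().has_overflow_serial([float('nan')]))  # True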
@@ -135,7 +143,7 @@ class DynamicLossScaler:
                 raise
             return True
         else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                 return True
             return False

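The rewritten condition is equivalent: cpu_sum != cpu_sum is the standard NaN test (NaN compares unequal to everything, itself included), and the membership test folds the two infinity comparisons into one expression. A quick check in plain Python:

    for cpu_sum in (1.0, float('inf'), -float('inf'), float('nan')):
        old = cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum
        new = cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum
        assert old == new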
@@ -157,17 +165,6 @@ class DynamicLossScaler:
             self.cur_scale *= self.scale_factor
         self.cur_iter += 1

-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-

 ##############################################################
 # Example usage below here -- assuming it's in a separate file
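The file's own tail advertises example usage. A hedged sketch of how a training loop might drive the refactored API; the model, optimizer, and unscale step are illustrative assumptions, not code from this commit:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = DynamicLossScaler()            # default arguments assumed

    loss = model(torch.randn(2, 4)).sum()
    scaler.backward(loss)                   # inherited from LossScalerBase
    overflow = scaler.has_overflow(model.parameters())
    scaler.update_scale(overflow)           # grow or shrink cur_scale dynamically
    if not overflow:
        for p in model.parameters():
            p.grad.data.mul_(1.0 / scaler.loss_scale)   # unscale before stepping
        optimizer.step()
    optimizer.zero_grad()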