OpenDAS / apex / Commits / 94667417

Commit 94667417
authored Mar 06, 2019 by Deyu Fu

change to use now great torch.norm

parent 1603407b
Showing 1 changed file with 6 additions and 30 deletions.

apex/optimizers/fp16_optimizer.py  (+6, -30)
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import ctypes
-
-stashed_err = None
-try:
-    lib = ctypes.cdll.LoadLibrary(None)
-    lib.THCudaHalfTensor_normall.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
-    lib.THCudaHalfTensor_normall.restype = ctypes.c_float
-
-    def fused_norm(input):
-        if input.type() == 'torch.cuda.HalfTensor':
-            # 16384 is half 2 if you stare at it long enough
-            return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
-                                                input._cdata, 16384)
-        else:
-            return input.norm()
-except TypeError as err:
-    stashed_err = err
-    def fused_norm(input):
-        raise RuntimeError("Failed to create fused_norm. This may happen on Windows "
-                           "because of lib = ctypes.cdll.LoadLibrary(None): you can't "
-                           "LoadLibrary with None. Original exception message was ",
-                           stashed_err)

 class FP16_Optimizer(object):
     """
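The magic constant in the removed branch is the norm order: 16384 is 0x4000, the IEEE 754 half-precision bit pattern for 2.0, so THCudaHalfTensor_normall was being asked for the 2-norm with p passed as a raw half ("16384 is half 2"). A quick check of that claim, assuming NumPy is available:

import numpy as np

# 2.0 in float16: sign 0, exponent 10000 (bias 15), mantissa 0
# -> bits 0b0100_0000_0000_0000 == 0x4000 == 16384
bits = np.array(2.0, dtype=np.float16).view(np.uint16).item()
assert bits == 16384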
@@ -137,9 +114,9 @@ class FP16_Optimizer(object):
             Total norm of the current fp16 gradients (viewed as a single vector).
             Returns -1 if the most recently computed fp16 gradients overflowed
         """
-        # TODO: currently using pre-1.0 api, and not most efficient with copy to cpu and sync
+        # TODO: Not most efficient with copy to cpu and sync
         # only support 2-norm now
-        norm = float(fused_norm(fp16_grads_flat))
+        norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
         if norm == float('inf') or norm == -float('inf') or norm != norm:
             return -1
         else:
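The new call relies on torch.norm's dtype argument: the 2-norm of the flattened fp16 gradients is accumulated and returned in fp32, so the result is not clipped to the fp16 range (max roughly 65504), while the check that follows still catches overflowed or NaN gradients (norm != norm is the NaN test). A minimal sketch of the pattern; fp16_grads_flat here is an illustrative stand-in for the optimizer's buffer, and a CUDA device is assumed since the fp16 gradients live on the GPU:

import torch

# Illustrative stand-in for the flattened fp16 gradient buffer.
fp16_grads_flat = torch.full((1 << 22,), 64.0, dtype=torch.float16, device='cuda')

# The true 2-norm is 64 * 2**11 = 131072, above the fp16 max (~65504), so an
# fp16 result would read as inf; requesting float32 keeps the exact value.
norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))   # 131072.0

# Same overflow test as in the diff: inf/-inf from overflowed grads, NaN via norm != norm.
overflowed = norm == float('inf') or norm == -float('inf') or norm != norm
print(norm, overflowed)   # 131072.0 False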
@@ -224,7 +201,7 @@ class FP16_Optimizer(object):
         self.optimizer.param_groups = value

     param_groups = property(_get_param_groups, _set_param_groups)

     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
@@ -250,9 +227,9 @@ class FP16_Optimizer(object):
     def load_state_dict(self, state_dict):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
         whose parameters in turn came from ``model``, it is expected that the user
         will call ``model.load_state_dict()`` before
         ``fp16_optimizer_instance.load_state_dict()`` is called.

         Example::
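The ordering requirement from that docstring reads, in code, roughly as follows. This is a hedged fragment rather than a runnable script: model and fp16_optimizer_instance are the objects the docstring refers to, and the checkpoint path and keys are illustrative:

import torch

checkpoint = torch.load('checkpoint.pt')

# Restore the model first, so the fp16 parameters wrapped by the optimizer are current.
model.load_state_dict(checkpoint['model'])

# Then restore the FP16_Optimizer wrapper (inner optimizer state plus the fp32 master params).
fp16_optimizer_instance.load_state_dict(checkpoint['optimizer'])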
@@ -289,4 +266,3 @@ class FP16_Optimizer(object):
         # are guaranteed to exist, so we can just copy_() from the saved master params.
         for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
             current.data.copy_(saved.data)
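A note on the last hunk: copy_() restores the saved fp32 master weights in place, so anything that already references the tensors in self.fp32_groups_flat (for example, the wrapped optimizer's param groups) keeps seeing the restored values; rebinding the list to freshly loaded tensors would break that aliasing. A small self-contained sketch of the same pattern, with illustrative names:

import torch

# Flat fp32 master buffers; something else already holds a reference to one of them.
fp32_groups_flat = [torch.zeros(4), torch.zeros(2)]
view_held_elsewhere = fp32_groups_flat[0]

# Values "loaded from a checkpoint".
saved = {'fp32_groups_flat': [torch.ones(4), torch.full((2,), 3.0)]}

# In-place copy, as in the diff: existing references stay valid.
for current, loaded in zip(fp32_groups_flat, saved['fp32_groups_flat']):
    current.data.copy_(loaded.data)

print(view_held_elsewhere)   # tensor([1., 1., 1., 1.])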