Commit 94667417 authored by Deyu Fu

change to use now great torch.norm

parent 1603407b
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import ctypes
-stashed_err = None
-try:
-    lib = ctypes.cdll.LoadLibrary(None)
-    lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
-    lib.THCudaHalfTensor_normall.restype = ctypes.c_float
-    def fused_norm(input):
-        if input.type() == 'torch.cuda.HalfTensor':
-            # 16384 is half 2 if you stare at it long enough
-            return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
-                                                input._cdata, 16384)
-        else:
-            return input.norm()
-except TypeError as err:
-    stashed_err = err
-    def fused_norm(input):
-        raise RuntimeError("Failed to create fused_norm. This may happen on Windows "
-                           "because of lib = ctypes.cdll.LoadLibrary(None): you can't "
-                           "LoadLibrary with None. Original exception message was ",
-                           stashed_err)
 class FP16_Optimizer(object):
     """
@@ -137,9 +114,9 @@ class FP16_Optimizer(object):
         Total norm of the current fp16 gradients (viewed as a single vector).
         Returns -1 if the most recently computed fp16 gradients overflowed
         """
-        # TODO: currently using pre-1.0 api, and not most efficient with copy to cpu and sync
+        # TODO: Not most efficient with copy to cpu and sync
         # only support 2-norm now
-        norm = float(fused_norm(fp16_grads_flat))
+        norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
         if norm == float('inf') or norm == -float('inf') or norm != norm:
             return -1
         else:
@@ -224,7 +201,7 @@ class FP16_Optimizer(object):
         self.optimizer.param_groups = value
     param_groups = property(_get_param_groups, _set_param_groups)
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
@@ -250,9 +227,9 @@ class FP16_Optimizer(object):
     def load_state_dict(self, state_dict):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
         whose parameters in turn came from ``model``, it is expected that the user
         will call ``model.load_state_dict()`` before
         ``fp16_optimizer_instance.load_state_dict()`` is called.
         Example::
@@ -289,4 +266,3 @@ class FP16_Optimizer(object):
         # are guaranteed to exist, so we can just copy_() from the saved master params.
         for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
             current.data.copy_(saved.data)
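
Note on the replacement (not part of the commit itself): a minimal sketch, assuming a recent PyTorch build, of what the new norm line computes. Passing dtype=torch.float32 to torch.norm makes the 2-norm reduction over the flat fp16 gradient buffer accumulate in fp32, which is what the removed ctypes call into THCudaHalfTensor_normall was doing; the magic 16384 it passed is 0x4000, i.e. 2.0 when read as an fp16 bit pattern (the "16384 is half 2" comment). The tensor below is only a stand-in for fp16_grads_flat.

import torch

# Stand-in for the flattened fp16 gradient buffer (illustrative name and shape).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
grad_dtype = torch.float16 if device == 'cuda' else torch.float32  # fp16 reductions on CPU may be limited
fp16_grads_flat = torch.randn(4096, device=device, dtype=grad_dtype)

# New path from the diff: 2-norm accumulated in fp32 via torch.norm's dtype argument.
norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))

# The removed C call took the norm power as an fp16 bit pattern: 16384 == 0x4000 == 2.0 in fp16.
assert torch.tensor([16384], dtype=torch.int16).view(torch.float16).item() == 2.0

# Same overflow convention as the diff above: inf/-inf/NaN means the fp16 gradients overflowed.
grad_norm = -1 if (norm == float('inf') or norm == -float('inf') or norm != norm) else norm
print(grad_norm)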