Commit 99495376 authored by Deyu Fu

make norm compatible with pytorch version <= 1.0.1

parent 94667417
@@ -116,7 +116,11 @@ class FP16_Optimizer(object):
"""
# TODO: Not most efficient with copy to cpu and sync
# only support 2-norm now
norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
# for torch version <= 1.0.1, torch.norm with dtype will fail and fall back to cast
try:
norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
except TypeError as err:
norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
if norm == float('inf') or norm == -float('inf') or norm != norm:
return -1
else:
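The same version-compatibility pattern can be lifted out of the class into a small standalone helper. The sketch below is illustrative, not part of the commit: the function name fp16_grad_norm and the dummy gradient tensor are assumptions, while the try/except fallback and the overflow check mirror the diff above.

import torch

def fp16_grad_norm(fp16_grads_flat):
    """Compute the 2-norm of a flattened fp16 gradient tensor in fp32.

    Newer PyTorch releases accept a dtype argument to torch.norm, which avoids
    materializing an fp32 copy of the gradients; releases <= 1.0.1 raise a
    TypeError for that argument, so we fall back to casting first.
    """
    try:
        norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
    except TypeError:
        norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
    # Same overflow signal as the diff: inf or NaN means the fp16 grads overflowed.
    if norm == float('inf') or norm == -float('inf') or norm != norm:
        return -1
    return norm

# Example usage with a dummy half-precision gradient tensor:
# grads = torch.randn(1024, dtype=torch.float16)
# print(fp16_grad_norm(grads))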