Unverified Commit 3ccdd63d authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

enable python only base sparse tensor support for loss scaling (#2)

parent e85a1d4b
......@@ -6,11 +6,17 @@ from itertools import product
def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
if model_grad.is_sparse:
cpu_sum = float(model_grad.float()._values().sum())
else:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
if master_grad is not model_grad: # copy_ probably internally short-circuits this
if model_grad.is_sparse:
master_grad.copy_(model_grad.to_dense())
else:
master_grad.copy_(model_grad)
if scale != 1.0:
master_grad.mul_(scale)
......@@ -19,6 +25,9 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
if model_grad.is_sparse:
cpu_sum = float(model_grad.float()._values().sum())
else:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment