Commit ef63acd2 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

re-added grad dtype assertion in clip_grads.

parent 9044bc42
......@@ -56,7 +56,11 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
grads_for_norm = [grads_for_norm]
# Grads.
grads = [ p.grad.detach() for p in parameters if p.grad is not None ]
grads = []
for param in parameters:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(param.grad.detach())
# Norm parameters.
max_norm = float(max_norm)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment