Commit c52c9dbf authored by Lawrence McAfee

fixed clip_grads for when len(grads_for_norm) == 0.

parent dcf2ccc9
@@ -75,12 +75,15 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
             # Use apex's multi-tensor applier for efficiency reasons.
             # Multi-tensor applier takes a function and a list of list
             # and performs the operation on that list all in one kernel.
-            grad_norm, _ = multi_tensor_applier(
-                amp_C.multi_tensor_l2norm,
-                dummy_overflow_buf,
-                [grads_for_norm],
-                False # no per-parameter norm
-            )
+            if grads_for_norm:
+                grad_norm, _ = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm,
+                    dummy_overflow_buf,
+                    [grads_for_norm],
+                    False # no per-parameter norm
+                )
+            else:
+                grad_norm = torch.cuda.FloatTensor([0])
             # Since we will be summing across data parallel groups,
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
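For context, a minimal sketch of the guarded pattern this commit introduces. It substitutes plain torch.norm for the apex multi_tensor_l2norm kernel and a CPU tensor for torch.cuda.FloatTensor([0]) so it runs anywhere; l2_grad_norm is a hypothetical helper name used only for illustration, not part of the repository.

import torch

def l2_grad_norm(grads_for_norm, norm_type=2.0):
    if grads_for_norm:
        # Norm of each grad, then the norm of those norms: the overall
        # L2 norm over all elements, matching what the fused kernel computes.
        grad_norm = torch.norm(
            torch.stack([torch.norm(g, norm_type) for g in grads_for_norm]),
            norm_type)
    else:
        # A rank may have no grads to norm (e.g. it holds only parameters
        # duplicated across model-parallel ranks). The fused kernel cannot
        # be invoked on an empty list, so contribute a zero instead.
        grad_norm = torch.zeros(1)
    # Raise to pow(norm_type) so values can be summed across data-parallel
    # groups before taking the final root.
    return grad_norm ** norm_type

print(l2_grad_norm([]))                             # tensor([0.]) - empty case is now a no-op under summation
print(l2_grad_norm([torch.ones(2), torch.ones(2)])) # tensor(4.) - overall norm 2, squared

Returning 0 for the empty case keeps the cross-rank reduction correct: summing total_norm over data-parallel groups is unaffected by ranks that contributed no gradients.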