Commit 57437cb1 authored by Rewon Child
Browse files

Fix syntax

parent 0aff3629
...@@ -129,18 +129,13 @@ def count_zeros_fp32(parameters): ...@@ -129,18 +129,13 @@ def count_zeros_fp32(parameters):
# - grad should not be none # - grad should not be none
# - parameter should not be shared # - parameter should not be shared
# - should not be a replica due to tensor model parallelism # - should not be a replica due to tensor model parallelism
grads_to_count = [] total_num_zeros = 0.0
for param in parameters: for param in parameters:
grad_not_none = param.grad is not None grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param) is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate: if grad_not_none and is_not_shared and is_not_tp_duplicate:
grad = param.grad.detach() grad = param.grad.detach()
grads_to_count.append(grad)
total_num_zeros = 0.0
for grad in grads_to_count:
num_zeros = grad.numel() - torch.count_nonzero(grad) num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros total_num_zeros = num_zeros + total_num_zeros
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment