Commit 9044bc42 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

removed count-zeros debuggables.

parent 7fccd6a1
......@@ -124,10 +124,7 @@ def count_zeros_fp32(parameters, model_parallel_group):
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
# >>>
# total_num_zeros = 0.0
total_num_zeros = torch.cuda.FloatTensor([0.0])
# <<<
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
......@@ -138,18 +135,9 @@ def count_zeros_fp32(parameters, model_parallel_group):
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
# >>>
try:
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=model_parallel_group)
except:
from lutil import pax
pax({
"total_num_zeros" : total_num_zeros,
"parameters" : parameters,
})
# <<<
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=model_parallel_group)
total_num_zeros = total_num_zeros.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment