Commit a9b1fc0a authored by Lawrence McAfee

finished cleaning clip_grads.py

parent 9546d8f0
@@ -21,18 +21,10 @@ from torch._six import inf
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
# >>>
# from megatron import mpu
# <<<
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# >>>
from lutil import pax, tp
DEBUG_ITERATION = 1
# <<<
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
                        model_parallel_group=None,
                        ITERATION=None):
@@ -56,10 +48,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
        Total norm of the parameters (viewed as a single vector).
    """
    # >>>
    # raise Exception("currently debugging ... don't call me.")
    # <<<

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
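
For orientation, a minimal usage sketch of this helper; the call site below is illustrative and not part of this commit, and mp_group stands in for whatever tensor-model-parallel process group the caller already holds:

# Hypothetical call site (names here are illustrative, not from this file).
# mp_group is assumed to be the caller's tensor-model-parallel process group.
params = [p for g in optimizer.param_groups for p in g['params']]
grad_norm = clip_grad_norm_fp32(params, max_norm=1.0, norm_type=2,
                                model_parallel_group=mp_group)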
@@ -119,17 +107,9 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        # >>>
        from megatron import get_args
        args = get_args()
        if args.use_distributed_optimizer:
            torch.distributed.all_reduce(total_norm,
                                         op=torch.distributed.ReduceOp.SUM)
        else:
            torch.distributed.all_reduce(total_norm,
                                         op=torch.distributed.ReduceOp.SUM,
                                         group=model_parallel_group)
        # <<<
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
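
The cleaned-up path keeps a single, unconditional reduction over the model-parallel group: each rank sums |g| ** norm_type over its local gradient shards, the partial sums are added across ranks, and the p-th root of the result is the global norm. A standalone sketch of that reduction in plain PyTorch, independent of Megatron's helpers (the process group is assumed to be initialized by the caller):

import torch

def global_grad_norm(local_grads, norm_type=2, group=None):
    # Sum of |g| ** norm_type over this rank's local gradient shards.
    total = torch.zeros(1, device=local_grads[0].device)
    for g in local_grads:
        total += torch.norm(g, norm_type) ** norm_type
    # Add the per-rank partial sums across the (model-parallel) group.
    torch.distributed.all_reduce(total,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=group)
    # The p-th root recovers the norm over the full, sharded parameter set.
    return total.item() ** (1.0 / norm_type)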
@@ -164,18 +144,9 @@ def count_zeros_fp32(parameters, model_parallel_group):
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    # >>>
    from megatron import get_args
    args = get_args()
    if args.use_distributed_optimizer:
        torch.distributed.all_reduce(total_num_zeros,
                                     op=torch.distributed.ReduceOp.SUM)
        # pax({"total_num_zeros": total_num_zeros.item()})
    else:
        torch.distributed.all_reduce(total_num_zeros,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
    # <<<
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=model_parallel_group)
    total_num_zeros = total_num_zeros.item()
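
The zero-gradient counter follows the same pattern: count the exactly-zero elements in this rank's relevant gradients, then sum the counts over the model-parallel group. A minimal standalone sketch (plain PyTorch; assumes torch.distributed is already initialized and that filtering of shared / tensor-parallel-duplicate parameters has been done by the caller):

def count_zero_grads(local_grads, group=None):
    # Local count of exactly-zero gradient elements on this rank.
    num_zeros = torch.zeros(1, device=local_grads[0].device)
    for g in local_grads:
        num_zeros += g.numel() - torch.count_nonzero(g)
    # Sum the per-rank counts across the (model-parallel) group.
    torch.distributed.all_reduce(num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=group)
    return num_zeros.item()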