Unverified Commit 4f42bbb0 authored by Shaden Smith, committed by GitHub

check for model_parallel attr (#204)

parent 7dc209c6
@@ -132,7 +132,9 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
     total_norm = 0.
     for p in parameters:
         if mpu is not None:
-            if p.model_parallel or (mpu.get_model_parallel_rank() == 0):
+            if (mpu.get_model_parallel_rank() == 0) or (hasattr(p,
+                                                                'model_parallel')
+                                                        and p.model_parallel):
                 param_norm = p.grad.data.float().norm(norm_type)
                 total_norm += param_norm.item()**norm_type
         else:
@@ -188,7 +190,9 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
     total_norm = 0.
     for p in parameters:
         if mpu is not None:
-            if p.model_parallel or (mpu.get_model_parallel_rank() == 0):
+            if (mpu.get_model_parallel_rank() == 0) or (hasattr(p,
+                                                                'model_parallel')
+                                                        and p.model_parallel):
                 try:
                     param_norm = float(torch.norm(p, norm_type, dtype=torch.float32))
                 except TypeError as err:
...
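In both functions the change is the same: the model-parallel rank check is evaluated first, and p.model_parallel is only read after hasattr(p, 'model_parallel') confirms the attribute exists, so parameters that were never tagged by a model-parallel layer no longer raise AttributeError; such untagged parameters are counted only on model-parallel rank 0. Below is a minimal sketch of the patched accumulation loop. FakeMPU, the example tensors, and the else-branch body (which lies outside the hunk above) are hypothetical; only get_model_parallel_rank() and the model_parallel flag come from the diff.

import torch


class FakeMPU:
    """Hypothetical stand-in for Megatron-style model-parallel utilities."""
    def get_model_parallel_rank(self):
        return 0


def get_grad_norm_sketch(parameters, norm_type=2, mpu=None):
    # Same accumulation shape as the patched function; the else branch body
    # is assumed, since it is not shown in the hunk above.
    total_norm = 0.
    for p in parameters:
        if mpu is not None:
            # hasattr() guards parameters that were never tagged with a
            # .model_parallel attribute; the old check read p.model_parallel
            # first and raised AttributeError on such tensors.
            if (mpu.get_model_parallel_rank() == 0) or (hasattr(p, 'model_parallel')
                                                        and p.model_parallel):
                param_norm = p.grad.data.float().norm(norm_type)
                total_norm += param_norm.item() ** norm_type
        else:
            param_norm = p.grad.data.float().norm(norm_type)
            total_norm += param_norm.item() ** norm_type
    return total_norm ** (1. / norm_type)


# Example: one tagged (sharded) and one untagged (replicated) parameter.
a = torch.nn.Parameter(torch.ones(4))
a.grad = torch.full((4,), 2.0)
a.model_parallel = True
b = torch.nn.Parameter(torch.ones(4))
b.grad = torch.full((4,), 3.0)  # no .model_parallel attribute at all

print(get_grad_norm_sketch([a, b], mpu=FakeMPU()))  # rank 0 counts both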