Unverified commit 4f42bbb0, authored by Shaden Smith and committed by GitHub
Browse files

check for model_parallel attr (#204)

parent 7dc209c6
......@@ -132,7 +132,9 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
total_norm = 0.
for p in parameters:
if mpu is not None:
- if p.model_parallel or (mpu.get_model_parallel_rank() == 0):
+ if (mpu.get_model_parallel_rank() == 0) or (hasattr(p,
+ 'model_parallel')
+ and p.model_parallel):
param_norm = p.grad.data.float().norm(norm_type)
total_norm += param_norm.item()**norm_type
else:
......@@ -188,7 +190,9 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
total_norm = 0.
for p in parameters:
if mpu is not None:
- if p.model_parallel or (mpu.get_model_parallel_rank() == 0):
+ if (mpu.get_model_parallel_rank() == 0) or (hasattr(p,
+ 'model_parallel')
+ and p.model_parallel):
try:
param_norm = float(torch.norm(p, norm_type, dtype=torch.float32))
except TypeError as err:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment