Unverified commit 54b3ad89 authored by littsk, committed by GitHub

[hotfix] fix norm type error in zero optimizer (#4795)

parent da15fdb9
@@ -221,8 +221,8 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
     else:
         total_norm = 0.0
         for g in gradients:
-            param_norm = g.data.double().norm(2)
-            total_norm += param_norm.item() ** 2
+            param_norm = g.data.double().norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
     # Sum across all model parallel GPUs.
     total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
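For context, the bug was that both the per-tensor norm and the accumulation exponent were hardcoded to 2, so a caller passing any other norm_type silently got the 2-norm. Below is a minimal, self-contained sketch of the corrected accumulation; total_grad_norm is a hypothetical helper for illustration only, not the repository's compute_norm, which additionally reduces the result across the dp/tp process groups.

```python
import torch

def total_grad_norm(gradients, norm_type: float = 2.0) -> float:
    # Hypothetical helper sketching the fixed logic. The p-norm of all
    # gradients viewed as one flat vector is (sum_i ||g_i||_p ** p) ** (1/p),
    # so both the per-tensor norm and the exponent must use norm_type,
    # not a hardcoded 2.
    total_norm = 0.0
    for g in gradients:
        param_norm = g.data.double().norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    return total_norm ** (1.0 / norm_type)

grads = [torch.randn(3, 3), torch.randn(5)]
print(total_grad_norm(grads, norm_type=2.0))  # same as the old hardcoded behavior
print(total_grad_norm(grads, norm_type=4.0))  # old code would wrongly return the 2-norm here
```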