Unverified Commit 8dc2030b authored by Benjamin Lefaudeux, committed by GitHub

[fix] compute the grad norm in fp32 (#520)

parent 82986ca0
@@ -120,5 +120,6 @@ def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
     if p == inf:
         local_norm = max(par.grad.detach().abs().max() for par in parameters)  # type: ignore
     else:
-        local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]), p)  # type: ignore
+        # Compute the norm in full precision no matter what
+        local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]), p).to(dtype=parameters[0].dtype)  # type: ignore
     return local_norm
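
For context on why the patch forces fp32: half precision tops out around 65504, so squaring any gradient entry above roughly 256 during an fp16 reduction already overflows to inf, even when the final norm would be representable. Below is a minimal, self-contained sketch of the patched helper. The function body mirrors the diff above; the parameter w and the gradient value 300.0 are hypothetical, chosen so that the fp32 accumulation stays finite while a pure fp16 reduction would not.

from math import inf
from typing import List

import torch


def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
    if p == inf:
        # The max-norm involves no accumulation, so the native dtype is safe.
        local_norm = max(par.grad.detach().abs().max() for par in parameters)
    else:
        # Accumulate each per-parameter norm in fp32, then cast the combined
        # result back to the parameters' dtype, as the patch does.
        local_norm = torch.norm(
            torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]),
            p,
        ).to(dtype=parameters[0].dtype)
    return local_norm


# Hypothetical usage: entries of 300.0 square to 90000, beyond fp16's range,
# so an fp16 sum of squares would hit inf; the fp32 path yields the exact
# norm sqrt(300**2 * 4096) = 19200, which fits back into fp16.
w = torch.nn.Parameter(torch.ones(4096, dtype=torch.float16))
w.grad = torch.full_like(w, 300.0)
print(calc_grad_norm([w], 2.0))  # tensor(19200., dtype=torch.float16)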