Unverified Commit 2df5ca2d authored by Joshua Meier's avatar Joshua Meier Committed by GitHub
Browse files

[fix] OSS - resolve fp16 overflow in clip grad norm (#263)

parent 2d9243bf
...@@ -259,7 +259,7 @@ class OSS(Optimizer): ...@@ -259,7 +259,7 @@ class OSS(Optimizer):
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group) dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
else: else:
local_norm = torch.norm( local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type).to(self._device) for p in local_params]), # type: ignore input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]), # type: ignore
p=norm_type, p=norm_type,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment