Unverified commit 208cce65, authored by Gustaf Ahdritz and committed by GitHub
Browse files

Merge pull request #295 from nikitos9000/improve_fp16_stability

Improve TriangularMultiplicativeUpdate stability in fp16 mode
parents ee5d2c35 6625e8df
@@ -392,8 +392,13 @@ class TriangleMultiplicativeUpdate(nn.Module):
b = mask b = mask
b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z) b = b * self.linear_b_p(z)
if(is_fp16_enabled()): # Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a = a / a.std()
b = b / b.std()
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float()) x = self._combine_projections(a.float(), b.float())
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment