Unverified Commit 9e80f972 authored by Sam Shleifer, committed by GitHub

Enable pegasus fp16 by clamping large activations (#7243)

* Clean clamp

* boom boom

* Take some other changes

* boom boom

* boom boom

* boom boom

* one chg

* fix test

* Use finfo

* style
parent be51c103
```diff
@@ -269,6 +269,9 @@ class EncoderLayer(nn.Module):
         x = residual + x
         if not self.normalize_before:
             x = self.final_layer_norm(x)
+        if torch.isinf(x).any() or torch.isnan(x).any():
+            clamp_value = torch.finfo(x.dtype).max - 1000
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
         return x, attn_weights
```
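
The guard only fires once an activation has already saturated: float16 tops out at 65504, so large post-residual sums overflow to inf, and the clamp pulls them back just inside the representable range. A minimal standalone sketch of the same idea (the helper name `clamp_inf_nan` is illustrative, not part of the diff):

```python
import torch

def clamp_inf_nan(x: torch.Tensor) -> torch.Tensor:
    # Only clamp when something actually overflowed, keeping the common path cheap.
    if torch.isinf(x).any() or torch.isnan(x).any():
        # torch.finfo(torch.float16).max == 65504, so clamp_value == 64504.0 here.
        clamp_value = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp_value, max=clamp_value)
    return x

# 1e5 is not representable in float16 and saturates to inf on conversion.
x = torch.tensor([1e5, -1e5, 3.0], dtype=torch.float16)
print(clamp_inf_nan(x))  # -> roughly tensor([ 64512., -64512., 3.], dtype=torch.float16)
```

Note that 64504 itself rounds to 64512 when stored back in float16, and that `torch.clamp` propagates NaNs unchanged; the NaN check mainly serves as an overflow signal here.
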
```diff
@@ -47,9 +47,11 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
         # Demonstrate fp16 issue, Contributions welcome!
         self.model.half()
         translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
-        decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
-        bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
-        self.assertListEqual(decoded, bad_fp16_result)
+        decoded_fp16 = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
+        assert decoded_fp16 == [
+            "California's largest electricity provider has begun",
+            "N-Dubz have revealed they were",
+        ]
 
 
 class PegasusConfigTests(unittest.TestCase):
```
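
The updated test expects real summaries in half precision rather than the old `unk_7` garbage. A hedged end-to-end sketch of the same scenario (the `google/pegasus-xsum` checkpoint name follows from the XSUM integration test; the input text is illustrative, and the rest assumes the standard transformers loading API and a CUDA device):

```python
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

name = "google/pegasus-xsum"
tokenizer = PegasusTokenizer.from_pretrained(name)
# .half() is the step that produced unk_7 output before the clamp landed.
model = PegasusForConditionalGeneration.from_pretrained(name).half().to("cuda")

article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
batch = tokenizer([article], return_tensors="pt", truncation=True).to("cuda")
summary_ids = model.generate(**batch, max_length=10)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```
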