Unverified commit 2d2ed2cc, authored by Patrick von Platen and committed by GitHub
Browse files

[T5] Fix speed degradation bug t5 (#10496)

* fix speed degradation bug t5

* fix for all models

* fix code quality
parent 5dc303e2
...@@ -319,7 +319,9 @@ class BartEncoderLayer(nn.Module): ...@@ -319,7 +319,9 @@ class BartEncoderLayer(nn.Module):
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
...@@ -322,7 +322,9 @@ class BlenderbotEncoderLayer(nn.Module): ...@@ -322,7 +322,9 @@ class BlenderbotEncoderLayer(nn.Module):
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
...@@ -320,7 +320,9 @@ class BlenderbotSmallEncoderLayer(nn.Module): ...@@ -320,7 +320,9 @@ class BlenderbotSmallEncoderLayer(nn.Module):
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
...@@ -925,7 +925,9 @@ class LEDEncoderLayer(nn.Module): ...@@ -925,7 +925,9 @@ class LEDEncoderLayer(nn.Module):
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
return (hidden_states,) + attn_outputs[1:] return (hidden_states,) + attn_outputs[1:]
......
...@@ -337,7 +337,9 @@ class MarianEncoderLayer(nn.Module): ...@@ -337,7 +337,9 @@ class MarianEncoderLayer(nn.Module):
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
...@@ -326,7 +326,9 @@ class MBartEncoderLayer(nn.Module): ...@@ -326,7 +326,9 @@ class MBartEncoderLayer(nn.Module):
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
...@@ -337,7 +337,9 @@ class PegasusEncoderLayer(nn.Module): ...@@ -337,7 +337,9 @@ class PegasusEncoderLayer(nn.Module):
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
...@@ -643,7 +643,7 @@ class T5Block(nn.Module): ...@@ -643,7 +643,7 @@ class T5Block(nn.Module):
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if torch.isinf(hidden_states).any(): if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
...@@ -668,7 +668,9 @@ class T5Block(nn.Module): ...@@ -668,7 +668,9 @@ class T5Block(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states = cross_attention_outputs[0]
if torch.isinf(hidden_states).any():
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
...@@ -681,9 +683,12 @@ class T5Block(nn.Module): ...@@ -681,9 +683,12 @@ class T5Block(nn.Module):
# Apply Feed Forward layer # Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states) hidden_states = self.layer[-1](hidden_states)
if torch.isinf(hidden_states).any():
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,) outputs = (hidden_states,)
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (present_key_value_state,) + attention_outputs
......
...@@ -1824,7 +1824,7 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module): ...@@ -1824,7 +1824,7 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
......
Markdown is supported
Attach a file by dragging & dropping or clicking to upload (0%).
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment