Unverified Commit ccd1923f authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[T5] enable T5 fp16 (#9487)

* fix t5 fp16
parent 2aa9c2f2
...@@ -640,6 +640,11 @@ class T5Block(nn.Module): ...@@ -640,6 +640,11 @@ class T5Block(nn.Module):
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention # the actual query length is unknown for cross attention
...@@ -661,6 +666,10 @@ class T5Block(nn.Module): ...@@ -661,6 +666,10 @@ class T5Block(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states = cross_attention_outputs[0]
if torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states # Combine self attn and cross attn key value states
if present_key_value_state is not None: if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1] present_key_value_state = present_key_value_state + cross_attention_outputs[1]
...@@ -670,6 +679,9 @@ class T5Block(nn.Module): ...@@ -670,6 +679,9 @@ class T5Block(nn.Module):
# Apply Feed Forward layer # Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states) hidden_states = self.layer[-1](hidden_states)
if torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,) outputs = (hidden_states,)
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (present_key_value_state,) + attention_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment