Unverified Commit 971da2e6 authored by Samuel Arcadinho, committed by GitHub

Clamping hidden state values to allow FP16 (#19229)



* Clamping hidden state values to allow FP16

* Reformatting

* Adding missing if condition

* Update src/transformers/models/longt5/modeling_longt5.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/longt5/modeling_longt5.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/longt5/modeling_longt5.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Formatting file

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 587d84b1
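
For readers outside the diff context, here is a minimal standalone sketch (my illustration, not code from this commit) of the failure mode and the fix. In half precision the largest finite value is `torch.finfo(torch.float16).max` (65504), so activations that are perfectly fine in fp32 can overflow to `inf`; the commit clamps them back to just under that limit:

```python
import torch

# Values that are finite in fp32 overflow to inf when cast to fp16,
# because fp16 tops out at 65504 while fp32 reaches ~3.4e38.
x = torch.tensor([70000.0, -70000.0, 1.0]).to(torch.float16)
print(x)  # tensor([inf, -inf, 1.], dtype=torch.float16)

# The guard added by this commit: if any inf appears in a float16 tensor,
# clamp to slightly below the dtype's maximum so later layers see finite values.
if x.dtype == torch.float16 and torch.isinf(x).any():
    clamp_value = torch.finfo(x.dtype).max - 1000
    x = torch.clamp(x, min=-clamp_value, max=clamp_value)
print(x)  # finite values near the fp16 maximum; no more inf
```

The dtype check keeps the guard a no-op for fp32/bf16 runs, and the `isinf().any()` test avoids clamping (and the extra kernel launch) on the common path where nothing overflowed.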
@@ -1199,6 +1199,11 @@ class LongT5Block(nn.Module):
         hidden_states, present_key_value_state = self_attention_outputs[:2]
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
 
+        # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
         do_cross_attention = self.is_decoder and encoder_hidden_states is not None
         if do_cross_attention:
             # the actual query length is unknown for cross attention
@@ -1221,6 +1226,11 @@ class LongT5Block(nn.Module):
             )
             hidden_states = cross_attention_outputs[0]
 
+            # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
+            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
             # Combine self attn and cross attn key value states
             if present_key_value_state is not None:
                 present_key_value_state = present_key_value_state + cross_attention_outputs[1]
@@ -1231,6 +1241,11 @@ class LongT5Block(nn.Module):
         # Apply Feed Forward layer
         hidden_states = self.layer[-1](hidden_states)
 
+        # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
         outputs = (hidden_states,)
 
         if use_cache:
...
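
With the clamping in place after each of the three sub-layers (self-attention, cross-attention, feed-forward), half-precision LongT5 inference should no longer produce nan/inf outputs. A hedged usage sketch (the checkpoint name and prompt are illustrative, and a CUDA device is assumed):

```python
import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

# Load a LongT5 checkpoint directly in float16 for GPU inference.
tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
model = LongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-tglobal-base", torch_dtype=torch.float16
).to("cuda")

# A long input is where LongT5's activations previously overflowed in fp16.
inputs = tokenizer("summarize: " + "a long document " * 200, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```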