"examples/vscode:/vscode.git/clone" did not exist on "6f72e71f97fd7bd114b10b4322e72ecbda283f3b"
Unverified commit 7c4999e4 authored by Prathik Rao, committed by GitHub

t5 remove data dependency (#22097)

* t5 remove data dependency
* make style
* make fix-copies

---------
Co-authored-by: Prathik Rao <prathikrao@microsoft.com>
parent 16121bae
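The change is identical in every hunk below: the Python-level branch on `torch.isinf(hidden_states).any()` becomes a `torch.where`, so the inf check stays a tensor op inside the graph instead of data-dependent control flow. A Python `if` that reads a tensor value forces a host sync, and `torch.jit.trace`/ONNX export bake in whichever branch the example input happened to take. A minimal sketch of the before/after pattern (the helper names are illustrative, not from the repo):

```python
import torch

def clamp_fp16_branchy(hidden_states: torch.Tensor) -> torch.Tensor:
    # Pre-patch pattern: the Python `if` reads a tensor value, so control
    # flow depends on the data. torch.jit.trace records only the branch
    # taken for the example input.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states

def clamp_fp16_branchless(hidden_states: torch.Tensor) -> torch.Tensor:
    # Post-patch pattern: the dtype check is static (safe to branch on in
    # Python), and the isinf check moves into a torch.where that stays in
    # the traced graph, selecting the clamp bound as a tensor value.
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states
```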
@@ -566,8 +566,12 @@ class MT5Block(nn.Module):
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

         # clamp inf values to enable fp16 training
-        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.where(
+                torch.isinf(hidden_states).any(),
+                torch.finfo(hidden_states.dtype).max - 1000,
+                torch.finfo(hidden_states.dtype).max,
+            )
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

         do_cross_attention = self.is_decoder and encoder_hidden_states is not None
@@ -593,8 +597,12 @@ class MT5Block(nn.Module):
             hidden_states = cross_attention_outputs[0]

             # clamp inf values to enable fp16 training
-            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            if hidden_states.dtype == torch.float16:
+                clamp_value = torch.where(
+                    torch.isinf(hidden_states).any(),
+                    torch.finfo(hidden_states.dtype).max - 1000,
+                    torch.finfo(hidden_states.dtype).max,
+                )
                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

             # Combine self attn and cross attn key value states
@@ -608,8 +616,12 @@ class MT5Block(nn.Module):
         hidden_states = self.layer[-1](hidden_states)

         # clamp inf values to enable fp16 training
-        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.where(
+                torch.isinf(hidden_states).any(),
+                torch.finfo(hidden_states.dtype).max - 1000,
+                torch.finfo(hidden_states.dtype).max,
+            )
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

         outputs = (hidden_states,)
@@ -703,8 +703,12 @@ class T5Block(nn.Module):
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

         # clamp inf values to enable fp16 training
-        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.where(
+                torch.isinf(hidden_states).any(),
+                torch.finfo(hidden_states.dtype).max - 1000,
+                torch.finfo(hidden_states.dtype).max,
+            )
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

         do_cross_attention = self.is_decoder and encoder_hidden_states is not None
@@ -730,8 +734,12 @@ class T5Block(nn.Module):
             hidden_states = cross_attention_outputs[0]

             # clamp inf values to enable fp16 training
-            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            if hidden_states.dtype == torch.float16:
+                clamp_value = torch.where(
+                    torch.isinf(hidden_states).any(),
+                    torch.finfo(hidden_states.dtype).max - 1000,
+                    torch.finfo(hidden_states.dtype).max,
+                )
                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

             # Combine self attn and cross attn key value states
@@ -745,8 +753,12 @@ class T5Block(nn.Module):
         hidden_states = self.layer[-1](hidden_states)

         # clamp inf values to enable fp16 training
-        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.where(
+                torch.isinf(hidden_states).any(),
+                torch.finfo(hidden_states.dtype).max - 1000,
+                torch.finfo(hidden_states.dtype).max,
+            )
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

         outputs = (hidden_states,)
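As a hedged illustration (not part of the commit), the branchless variant from the sketch above traces cleanly; with the old pattern, torch.jit.trace would warn about converting a tensor to a Python boolean and freeze whichever branch the example input took:

```python
# Illustrative only: trace the branchless helper with an inf-free input.
x = torch.randn(2, 8, dtype=torch.float16)
traced = torch.jit.trace(clamp_fp16_branchless, (x,))

# Because the isinf check is a tensor op inside the graph, the same
# traced function also clamps inputs that do contain infs.
has_inf = torch.full((2, 8), float("inf"), dtype=torch.float16)
assert torch.isfinite(traced(has_inf)).all()
```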