"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3258ff93304078b9e27d752e6c19d3813f664855"
Unverified Commit cbaaa2f6 authored by Joao Gante, committed by GitHub

Flax dtype-dependent numerical masking (#21197)

parent 0b86e330
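Why the change helps, as a minimal sketch that is not part of the commit itself: the masking constants previously hardcoded across the Flax models (-1e4, -1e9, -1e10, float("-inf")) are either not representable in reduced precision (for example, -1e9 overflows to -inf in float16) or not uniformly "very negative" relative to the compute dtype, whereas jnp.finfo(self.dtype).min is always the most negative finite value of whatever dtype the module runs in. The attention_mask, dtype, and shapes below are illustrative stand-ins, not values taken from the diff.

import jax.numpy as jnp
from jax import lax

# Hardcoded constant vs. the dtype-aware minimum across common compute dtypes.
for dtype in (jnp.float32, jnp.float16, jnp.bfloat16):
    hardcoded = jnp.asarray(-1e9, dtype=dtype)                    # overflows to -inf in float16
    dtype_aware = jnp.asarray(jnp.finfo(dtype).min, dtype=dtype)  # always finite in `dtype`
    print(dtype.__name__, hardcoded, dtype_aware)

# The same dtype-aware constant used as an attention bias, mirroring the hunks below.
attention_mask = jnp.array([[1, 1, 0]])  # 1 = attend, 0 = masked (illustrative)
dtype = jnp.float16
attention_bias = lax.select(
    attention_mask > 0,
    jnp.full(attention_mask.shape, 0.0).astype(dtype),
    jnp.full(attention_mask.shape, jnp.finfo(dtype).min).astype(dtype),
)
print(attention_bias)  # masked position gets -65504 (the float16 minimum) rather than -inf

After the softmax, a finite, dtype-appropriate bias still drives the masked positions' weights to numerical zero, while avoiding the inf/NaN propagation that a non-representable constant can cause in half precision.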
@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )

         # usual dot product attention
@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )

         # usual dot product attention
@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None