"tests/test_configuration_utils.py" did not exist on "923110b74fbef83909035b88c7eb6f7c4b6a8397"
Unverified commit cbaaa2f6, authored by Joao Gante, committed by GitHub

Flax dtype-dependent numerical masking (#21197)

parent 0b86e330
@@ -312,7 +312,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -1859,7 +1859,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
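For context, here is a minimal standalone sketch (not part of the commit; the function name is hypothetical) of the masking pattern being changed. With float16/bfloat16 modules, a hard-coded -1e10 overflows to -inf when cast, and float("-inf") can turn fully masked rows into NaNs after softmax, whereas jnp.finfo(dtype).min stays finite for any floating dtype.

import jax.numpy as jnp
from jax import lax

def make_attention_bias(attention_mask, dtype):
    # attention_mask: 1 where attention is allowed, 0 where it is masked out.
    return lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(dtype),
        jnp.full(attention_mask.shape, jnp.finfo(dtype).min).astype(dtype),
    )

mask = jnp.array([[1, 1, 0, 0]])
print(make_attention_bias(mask, jnp.float32))  # masked slots: -3.4028235e+38
print(make_attention_bias(mask, jnp.float16))  # masked slots: -65504.0, still finite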