Unverified Commit 2596f95e authored by Sanchit Gandhi, committed by GitHub

Fix Embedding Module Bug in Flax Models (#15920)

parent 1a62b25c
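The change below removes the optional `embed_tokens` attribute and the lazy fallback that built an `nn.Embed` inside `setup()` when none was given; `embed_tokens` becomes a required module attribute, so callers (e.g. the parent model modules) must now always supply the shared token embedding. A minimal sketch of that sharing pattern, written directly against Flax; the `ToyEncoder`/`ToyModel` names and sizes are illustrative and not part of this commit:

# Sketch of the shared-embedding pattern this fix moves to (toy names, not the
# actual transformers classes): the parent builds one nn.Embed in setup() and passes
# it to the sub-module as a required attribute, rather than the sub-module creating
# its own embedding when the attribute is None.
import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyEncoder(nn.Module):
    embed_tokens: nn.Embed  # required: supplied by the parent module
    dtype: jnp.dtype = jnp.float32

    def __call__(self, input_ids):
        # the shared embedding is used as-is; nothing is re-created here
        return self.embed_tokens(input_ids).astype(self.dtype)

class ToyModel(nn.Module):
    vocab_size: int = 128
    d_model: int = 16

    def setup(self):
        # single source of truth for the token embedding
        self.shared = nn.Embed(
            self.vocab_size,
            self.d_model,
            embedding_init=jax.nn.initializers.normal(0.02),
        )
        self.encoder = ToyEncoder(embed_tokens=self.shared)

    def __call__(self, input_ids):
        return self.encoder(input_ids)

# initialising the parent creates the embedding parameters exactly once, under "shared"
params = ToyModel().init(jax.random.PRNGKey(0), jnp.zeros((1, 4), dtype=jnp.int32))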
@@ -697,8 +697,8 @@ class FlaxBartClassificationHead(nn.Module):
 class FlaxBartEncoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -708,13 +708,6 @@ class FlaxBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
@@ -768,8 +761,8 @@ class FlaxBartEncoder(nn.Module):
 class FlaxBartDecoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -779,13 +772,6 @@ class FlaxBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
......
@@ -661,8 +661,8 @@ class FlaxBlenderbotDecoderLayerCollection(nn.Module):
 class FlaxBlenderbotEncoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -672,13 +672,6 @@ class FlaxBlenderbotEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
@@ -730,8 +723,8 @@ class FlaxBlenderbotEncoder(nn.Module):
 class FlaxBlenderbotDecoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -741,13 +734,6 @@ class FlaxBlenderbotDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
......
@@ -674,8 +674,8 @@ class FlaxBlenderbotSmallDecoderLayerCollection(nn.Module):
 class FlaxBlenderbotSmallEncoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -685,13 +685,6 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
@@ -742,8 +735,8 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
 class FlaxBlenderbotSmallDecoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -753,13 +746,6 @@ class FlaxBlenderbotSmallDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
......
@@ -684,8 +684,8 @@ class FlaxMarianDecoderLayerCollection(nn.Module):
 class FlaxMarianEncoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -694,13 +694,6 @@ class FlaxMarianEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)
@@ -747,8 +740,8 @@ class FlaxMarianEncoder(nn.Module):
 class FlaxMarianDecoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -757,13 +750,6 @@ class FlaxMarianDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
......
@@ -712,8 +712,8 @@ class FlaxMBartClassificationHead(nn.Module):
 class FlaxMBartEncoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -723,13 +723,6 @@ class FlaxMBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
@@ -787,8 +780,8 @@ class FlaxMBartEncoder(nn.Module):
 class FlaxMBartDecoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -798,13 +791,6 @@ class FlaxMBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
......
@@ -677,8 +677,8 @@ class FlaxPegasusDecoderLayerCollection(nn.Module):
 class FlaxPegasusEncoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -688,13 +688,6 @@ class FlaxPegasusEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = create_sinusoidal_positions(
             self.config.max_position_embeddings, embed_dim, dtype=self.dtype
         )
@@ -746,8 +739,8 @@ class FlaxPegasusEncoder(nn.Module):
 class FlaxPegasusDecoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -757,13 +750,6 @@ class FlaxPegasusDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.embed_positions = create_sinusoidal_positions(
             self.config.max_position_embeddings, embed_dim, dtype=self.dtype
         )
......
@@ -709,19 +709,12 @@ class FlaxT5BlockCollection(nn.Module):
 class FlaxT5Stack(nn.Module):
     config: T5Config
-    embed_tokens: Optional[nn.Embed] = None
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     def setup(self):
         self.causal = self.config.causal
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                self.config.d_model,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
         self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
         self.final_layer_norm = FlaxT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
......