"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5ff6f853d7d5247924d94e278a4cedc37a5885a4"
Unverified commit ac2f6674, authored by Lianmin Zheng, committed by GitHub

[FLAX] Add dtype to embedding for bert/bart/opt/t5 (#20340)

* [FLAX] Add dtype to embedding for bert/bart/opt/t5

* Fix all copies

* Add a test case
parent 667ccea7
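
For background: Flax's nn.Embed accepts a dtype argument that controls the dtype of the lookup output. When it is left at its default, a module that otherwise runs in half precision still gets float32 activations out of its embedding tables, which costs memory and breaks dtype expectations downstream. A minimal sketch of the mechanism (the Toy module and its sizes are illustrative, not from this commit):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Toy(nn.Module):
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed = nn.Embed(
            100,  # illustrative vocab size
            16,   # illustrative feature size
            embedding_init=jax.nn.initializers.normal(0.02),
            dtype=self.dtype,  # the line this commit adds to each model
        )

    def __call__(self, input_ids):
        return self.embed(input_ids)

ids = jnp.ones((1, 4), dtype=jnp.int32)
model = Toy(dtype=jnp.bfloat16)
params = model.init(jax.random.PRNGKey(0), ids)
print(model.apply(params, ids).dtype)  # bfloat16; float32 without dtype=...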
@@ -715,6 +715,7 @@ class FlaxBartEncoder(nn.Module):
            self.config.max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -779,6 +780,7 @@ class FlaxBartDecoder(nn.Module):
            self.config.max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
@@ -842,6 +844,7 @@ class FlaxBartModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -1888,6 +1891,7 @@ class FlaxBartDecoderWrapper(nn.Module):
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
...
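
The practical effect on BART, sketched as a hedged usage example (checkpoint name is only illustrative): a model loaded with a half-precision dtype now keeps half-precision activations from the shared embedding onward instead of silently promoting to float32.

import jax.numpy as jnp
from transformers import BartTokenizer, FlaxBartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartModel.from_pretrained("facebook/bart-base", dtype=jnp.float16)

inputs = tokenizer("Hello world", return_tensors="np")
outputs = model(**inputs)
# With this patch the embedding lookup itself runs in float16.
print(outputs.last_hidden_state.dtype)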
@@ -187,16 +187,19 @@ class FlaxBertEmbeddings(nn.Module):
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...
@@ -205,16 +205,19 @@ class FlaxBigBirdEmbeddings(nn.Module):
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...
@@ -817,6 +817,7 @@ class FlaxBlenderbotModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.encoder = FlaxBlenderbotEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...
@@ -815,6 +815,7 @@ class FlaxBlenderbotSmallModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.encoder = FlaxBlenderbotSmallEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...
@@ -368,6 +368,7 @@ class FlaxLongT5Attention(nn.Module):
            self.relative_attention_num_buckets,
            self.n_heads,
            embedding_init=jax.nn.initializers.normal(kv_init_std),
+           dtype=self.dtype,
        )

    @staticmethod
@@ -2032,6 +2033,7 @@ class FlaxLongT5Module(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+           dtype=self.dtype,
        )
        encoder_config = copy.deepcopy(self.config)
@@ -2160,6 +2162,7 @@ class FlaxLongT5ForConditionalGenerationModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+           dtype=self.dtype,
        )
        encoder_config = copy.deepcopy(self.config)
...
@@ -881,6 +881,7 @@ class FlaxMBartModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...
@@ -436,12 +436,14 @@ class FlaxOPTDecoder(nn.Module):
            self.config.vocab_size,
            self.config.word_embed_proj_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
            self.config.max_position_embeddings,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        if self.config.word_embed_proj_dim != self.config.hidden_size:
...
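
OPT needs the change in two places: the token table and the learned positional embedding (FlaxOPTLearnedPositionalEmbedding subclasses nn.Embed, so it takes the same dtype argument). A quick decoder-only check, again an illustrative sketch rather than part of the commit:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxOPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = FlaxOPTForCausalLM.from_pretrained("facebook/opt-125m", dtype=jnp.float16)

inputs = tokenizer("Hello world", return_tensors="np")
logits = model(**inputs).logits
print(logits.dtype)  # float16 once both embedding tables honor self.dtype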
@@ -831,6 +831,7 @@ class FlaxPegasusModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+           dtype=self.dtype,
        )
        self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...
@@ -147,16 +147,19 @@ class FlaxRobertaEmbeddings(nn.Module):
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+           dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...
@@ -228,6 +228,7 @@ class FlaxT5Attention(nn.Module):
            self.relative_attention_num_buckets,
            self.n_heads,
            embedding_init=jax.nn.initializers.normal(kv_init_std),
+           dtype=self.dtype,
        )

    @staticmethod
@@ -1292,6 +1293,7 @@ class FlaxT5Module(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+           dtype=self.dtype,
        )
        encoder_config = copy.deepcopy(self.config)
@@ -1417,6 +1419,7 @@ class FlaxT5EncoderModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+           dtype=self.dtype,
        )
        encoder_config = copy.deepcopy(self.config)
@@ -1512,6 +1515,7 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+           dtype=self.dtype,
        )
        encoder_config = copy.deepcopy(self.config)
...
@@ -865,6 +865,21 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
        output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
        self.assertTrue(output_str == "Hello there!")

+   @slow
+   def test_small_generation_bfloat16(self):
+       model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
+       model.config.max_length = 8
+       model.config.num_beams = 1
+       model.config.do_sample = False
+
+       tokenizer = T5Tokenizer.from_pretrained("t5-small")
+       input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
+
+       sequences = model.generate(input_ids).sequences
+
+       output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+       self.assertTrue(output_str == "Hello there!")
+
    @slow
    def test_summarization(self):
        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
...
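
Beyond the string match in the new test, the dtype can be inspected directly. A sketch (not part of the commit) that checks the encoder activations stay in bfloat16:

import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("t5-small")

input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
encoder_outputs = model.encode(input_ids=input_ids)
# Before this change the shared embedding computed in float32; with dtype
# threaded through, the encoder output keeps the model's bfloat16.
print(encoder_outputs.last_hidden_state.dtype)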