"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "88111de07c40797aaca619be693616c3c4cda4bd"
Unverified Commit f44e2c2b authored by Arthur, committed by GitHub

Fix TF shared embedding (#17730)

* fix the naming

* from pt in test for now

* make style

* slow test and removed from_pt
parent 2eadb7e5
@@ -30,7 +30,6 @@ from ...modeling_tf_utils import (
     TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
-    TFWrappedEmbeddings,
     keras_serializable,
     unpack_inputs,
 )
@@ -495,31 +494,15 @@ class TFOPTDecoder(tf.keras.layers.Layer):
         self.padding_idx = config.pad_token_id
         self.layerdrop = config.layerdrop
         num_embeddings = config.max_position_embeddings
-        self.shared = TFSharedEmbeddings(
-            config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="model.decoder.embed_tokens"
+        self.embed_tokens = TFSharedEmbeddings(
+            config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens"
         )
         self.embed_positions = TFOPTLearnedPositionalEmbedding(
             num_embeddings,
             config.hidden_size,
             name="embed_positions",
         )
-        # set tf scope correctly
-        if load_weight_prefix is None:
-            load_weight_prefix = "decoder.embed_tokens"
-
-        with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
-            pass
-
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        embed_tokens.vocab_size = self.shared.vocab_size
-        embed_tokens.hidden_size = self.shared.hidden_size
-        self.embed_tokens = embed_tokens
-
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
             self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
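For context, the wrapper and the tf.compat.v1.variable_scope bookkeeping can be dropped because TFSharedEmbeddings already serves both directions of the tying: an embedding lookup by default, and the tied output projection via mode="linear" (as the last hunk below uses). A minimal sketch of the dual-mode call; the sizes here are hypothetical and not taken from any OPT config:

import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings

# Hypothetical sizes, for illustration only.
embed = TFSharedEmbeddings(vocab_size=50272, hidden_size=512, name="embed_tokens")

token_ids = tf.constant([[0, 2, 5]])    # (batch, seq_len)
hidden = embed(token_ids)               # embedding lookup      -> (1, 3, 512)
logits = embed(hidden, mode="linear")   # tied output projection -> (1, 3, 50272)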
@@ -538,17 +521,11 @@ class TFOPTDecoder(tf.keras.layers.Layer):
         self.embed_tokens = embed_tokens

     def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        self.shared.vocab_size = self.shared.weight.shape[0]
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("decoder.embed_tokens") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.set_embed_tokens(embed_tokens)
+        self.embed_tokens.vocab_size = new_embeddings.shape[0]
+        self.embed_tokens.weight = new_embeddings

     def get_input_embeddings(self):
-        return self.shared
+        return self.embed_tokens

     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
         # create causal mask
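With the wrapper gone, swapping the embedding matrix becomes a plain attribute assignment on the layer itself. A hypothetical round trip, assuming a built TFOPTDecoder instance named decoder whose embedding (projection) size is 512:

import tensorflow as tf

# Hypothetical: "decoder" is a built TFOPTDecoder; sizes are illustrative.
new_weights = tf.random.normal((50274, 512))   # e.g. vocab grown by two tokens
decoder.set_input_embeddings(new_weights)
assert decoder.get_input_embeddings() is decoder.embed_tokens
assert decoder.embed_tokens.vocab_size == 50274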
@@ -731,7 +708,7 @@ class TFOPTMainLayer(tf.keras.layers.Layer):
         self.decoder = TFOPTDecoder(config, name="decoder")

     def get_input_embeddings(self):
-        return self.decoder.shared
+        return self.decoder.embed_tokens

     def set_input_embeddings(self, new_embeddings):
         self.decoder.set_input_embeddings(new_embeddings)
@@ -797,7 +774,7 @@ class TFOPTModel(TFOPTPreTrainedModel):
         self.model = TFOPTMainLayer(config, name="model")

     def get_input_embeddings(self):
-        return self.model.decoder.shared
+        return self.model.decoder.embed_tokens

     def set_input_embeddings(self, new_embeddings):
         self.model.set_input_embeddings(new_embeddings)
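Since TFOPTMainLayer and TFOPTModel both delegate to the decoder, the accessors now resolve to the same layer at every level. A hypothetical check:

# Hypothetical: "model" is a TFOPTModel instance; all accessor levels
# return the one shared TFSharedEmbeddings layer.
embeddings = model.get_input_embeddings()
assert embeddings is model.model.decoder.embed_tokens
assert embeddings is model.model.get_input_embeddings()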
@@ -1013,8 +990,7 @@ class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
             training=training,
         )
-        logits = self.model.decoder.shared(outputs[0], mode="linear")
-
+        logits = self.model.decoder.embed_tokens(outputs[0], mode="linear")
         loss = None

         if labels is not None:
             # shift labels to the left and cut last logit token
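End to end, the LM head projection now runs through the renamed embed_tokens layer. A usage sketch, assuming TF weights for the facebook/opt-350m checkpoint are available on the Hub (the commit notes suggest the from_pt fallback was only temporary):

from transformers import GPT2Tokenizer, TFOPTForCausalLM

# Checkpoint is illustrative; any OPT checkpoint with TF weights should work.
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
model = TFOPTForCausalLM.from_pretrained("facebook/opt-350m")

inputs = tokenizer("Hello, my dog is", return_tensors="tf")
logits = model(**inputs).logits   # projected with the tied embed_tokens weights
assert logits.shape[-1] == model.config.vocab_size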