"...retiarii_test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "4784cc6c1abb1a15b8a73d1022e191ccc26272e9"
Unverified Commit 860d11ff authored by Wissam Antoun, committed by GitHub

Fix Debertav2 embed_proj (#24205)

* MLM prediction head output size from embed_size

Take the output size of the dense projection layer from embedding_size instead of hidden_size, since the input embeddings may be projected from embedding_size up to hidden_size when the two differ; the sketch below illustrates the shape constraint.
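
A minimal sketch of the constraint (hypothetical sizes, not the library code): when the MLM decoder is tied to the word embeddings, its weight matrix has embedding_size columns, so the prediction-head transform must project hidden states from hidden_size back down to embedding_size before the tied projection.

import torch
from torch import nn

# hypothetical DeBERTaV3-generator-style sizes
vocab_size, hidden_size, embedding_size = 128100, 768, 256

word_embeddings = nn.Embedding(vocab_size, embedding_size)      # input embeddings
decoder = nn.Linear(embedding_size, vocab_size, bias=False)
decoder.weight = word_embeddings.weight                         # tied output weights: (vocab_size, embedding_size)

hidden_states = torch.randn(2, 8, hidden_size)                  # encoder output: (batch, seq, hidden_size)

# Transform projects hidden_size -> embedding_size (previously hidden_size -> hidden_size),
# so the tied decoder receives features of the size it expects.
dense = nn.Linear(hidden_size, embedding_size)
act = nn.GELU()
layer_norm = nn.LayerNorm(embedding_size)

logits = decoder(layer_norm(act(dense(hidden_states))))
print(logits.shape)                                             # torch.Size([2, 8, 128100])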

* project TFDebertaV2 mlm output to embedding size

embedding_size can be different from hidden_size, so the final layer needs to project back to embedding_size, as in ELECTRA- or DeBERTaV3-style pretraining.

This should solve an error that occurs when loading models like "almanach/camemberta-base-generator" (see the loading check below).
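
A quick check of the reported failure, using the checkpoint named above (hedged: from_pt=True may be needed if the repository publishes only PyTorch weights):

from transformers import TFDebertaV2ForMaskedLM

# Before this fix, loading failed with a weight-shape mismatch in the MLM head
# (hidden_size used where the checkpoint expects embedding_size).
model = TFDebertaV2ForMaskedLM.from_pretrained("almanach/camemberta-base-generator")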

* fix the same issue for reshaping after projection

* fix layernorm size

* add self.embedding_size to scope

* fix embed_proj scope name

* apply the same changes to TF Deberta

* add the changes to deberta

* added self.embedding_size instead of config.embedding_size

* added the same change to debertav2

* added "copied from deberta" comments to the deberta2 model

* config.embedding_size fix

* black

* fix deberta config name
parent a04ebc8b

src/transformers/models/deberta/modeling_deberta.py

@@ -1100,16 +1100,17 @@ class DebertaForMaskedLM(DebertaPreTrainedModel):
 )


-# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
 class DebertaPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
         if isinstance(config.hidden_act, str):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)

@@ -1118,15 +1119,15 @@ class DebertaPredictionHeadTransform(nn.Module):
         return hidden_states


-# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
 class DebertaLMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.transform = DebertaPredictionHeadTransform(config)
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)

         self.bias = nn.Parameter(torch.zeros(config.vocab_size))

src/transformers/models/deberta/modeling_tf_deberta.py

@@ -726,7 +726,12 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer):
         self.position_biased_input = getattr(config, "position_biased_input", True)
         self.initializer_range = config.initializer_range
         if self.embedding_size != config.hidden_size:
-            self.embed_proj = tf.keras.layers.Dense(config.hidden_size, use_bias=False)
+            self.embed_proj = tf.keras.layers.Dense(
+                config.hidden_size,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="embed_proj",
+                use_bias=False,
+            )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")

@@ -820,8 +825,10 @@ class TFDebertaPredictionHeadTransform(tf.keras.layers.Layer):
     def __init__(self, config: DebertaConfig, **kwargs):
         super().__init__(**kwargs)

+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+
         self.dense = tf.keras.layers.Dense(
-            units=config.hidden_size,
+            units=self.embedding_size,
             kernel_initializer=get_initializer(config.initializer_range),
             name="dense",
         )

@@ -845,7 +852,7 @@ class TFDebertaLMPredictionHead(tf.keras.layers.Layer):
         super().__init__(**kwargs)

         self.config = config
-        self.hidden_size = config.hidden_size
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

         self.transform = TFDebertaPredictionHeadTransform(config, name="transform")

@@ -875,7 +882,7 @@ class TFDebertaLMPredictionHead(tf.keras.layers.Layer):
     def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
         hidden_states = self.transform(hidden_states=hidden_states)
         seq_length = shape_list(hidden_states)[1]
-        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
+        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
         hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
         hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
         hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

src/transformers/models/deberta_v2/modeling_deberta_v2.py

@@ -1199,16 +1199,18 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
 )


-# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
+# Copied from transformers.models.deberta.modeling_deberta.DebertaPredictionHeadTransform with Deberta->DebertaV2
 class DebertaV2PredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
         if isinstance(config.hidden_act, str):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)

@@ -1217,15 +1219,16 @@ class DebertaV2PredictionHeadTransform(nn.Module):
         return hidden_states


-# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
+# Copied from transformers.models.deberta.modeling_deberta.DebertaLMPredictionHead with Deberta->DebertaV2
 class DebertaV2LMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.transform = DebertaV2PredictionHeadTransform(config)
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)

         self.bias = nn.Parameter(torch.zeros(config.vocab_size))

src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py

@@ -816,7 +816,12 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer):
         self.position_biased_input = getattr(config, "position_biased_input", True)
         self.initializer_range = config.initializer_range
         if self.embedding_size != config.hidden_size:
-            self.embed_proj = tf.keras.layers.Dense(config.hidden_size, use_bias=False)
+            self.embed_proj = tf.keras.layers.Dense(
+                config.hidden_size,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="embed_proj",
+                use_bias=False,
+            )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")

@@ -911,8 +916,10 @@ class TFDebertaV2PredictionHeadTransform(tf.keras.layers.Layer):
     def __init__(self, config: DebertaV2Config, **kwargs):
         super().__init__(**kwargs)

+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+
         self.dense = tf.keras.layers.Dense(
-            units=config.hidden_size,
+            units=self.embedding_size,
             kernel_initializer=get_initializer(config.initializer_range),
             name="dense",
         )

@@ -937,7 +944,7 @@ class TFDebertaV2LMPredictionHead(tf.keras.layers.Layer):
         super().__init__(**kwargs)

         self.config = config
-        self.hidden_size = config.hidden_size
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

         self.transform = TFDebertaV2PredictionHeadTransform(config, name="transform")

@@ -967,7 +974,7 @@ class TFDebertaV2LMPredictionHead(tf.keras.layers.Layer):
     def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
         hidden_states = self.transform(hidden_states=hidden_states)
         seq_length = shape_list(hidden_states)[1]
-        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
+        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
         hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
         hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
         hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)