Unverified Commit 277fc2cc authored by Dan Tegzes, committed by GitHub

Update flaubert with tf decorator (#16258)

parent 75c666b4
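The change is mechanical: instead of each `call` method funnelling its arguments through `input_processing(...)` and reading them back out of the returned `inputs` dict, the method is decorated with `@unpack_inputs` (imported from `modeling_tf_utils` in place of `input_processing`) and uses its named parameters directly. A minimal sanity check of the public behaviour is sketched below; the `flaubert/flaubert_small_cased` checkpoint, the example sentence, and the keyword-versus-dict comparison are illustrative assumptions, not part of this commit.

```python
# Minimal sketch: TFFlaubertModel should accept both keyword and dict-style
# inputs exactly as before, since @unpack_inputs now performs the normalization
# that input_processing used to do inside call().
# Assumes TensorFlow, the tokenizer dependencies, and the
# flaubert/flaubert_small_cased checkpoint are available.
from transformers import FlaubertTokenizer, TFFlaubertModel

tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_small_cased")
model = TFFlaubertModel.from_pretrained("flaubert/flaubert_small_cased")

enc = tokenizer("Le chat dort sur le canapé.", return_tensors="tf")

# Keyword arguments are normalized by the decorator before call() runs ...
out_kwargs = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
# ... and a single dict / BatchEncoding passed positionally goes through the same path.
out_dict = model(enc)

print(out_kwargs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
print(out_dict.last_hidden_state.shape)    # same shape from the dict-style call
```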
@@ -37,8 +37,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -235,6 +235,7 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFFlaubertMainLayer(config, name="transformer")
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -259,9 +260,7 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
         training: Optional[bool] = False,
         **kwargs,
     ) -> Union[Tuple, TFBaseModelOutput]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             langs=langs,
@@ -275,22 +274,6 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         return outputs
@@ -491,6 +474,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         self.embeddings.weight = value
         self.embeddings.vocab_size = shape_list(value)[0]
+    @unpack_inputs
     def call(
         self,
         input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
@@ -509,49 +493,31 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         **kwargs,
     ) -> Union[Tuple, TFBaseModelOutput]:
         # removed: src_enc=None, src_len=None
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            langs=langs,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            lengths=lengths,
-            cache=cache,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            bs, slen = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            bs, slen = shape_list(inputs["inputs_embeds"])[:2]
+        elif input_ids is not None:
+            bs, slen = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            bs, slen = shape_list(inputs_embeds)[:2]
         else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
-        if inputs["lengths"] is None:
-            if inputs["input_ids"] is not None:
-                inputs["lengths"] = tf.reduce_sum(
-                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1
+        if lengths is None:
+            if input_ids is not None:
+                lengths = tf.reduce_sum(
+                    tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1
                 )
             else:
-                inputs["lengths"] = tf.convert_to_tensor([slen] * bs)
+                lengths = tf.convert_to_tensor([slen] * bs)
         # mask = input_ids != self.pad_index
         # check inputs
         # assert shape_list(lengths)[0] == bs
         if tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["lengths"])[0], bs
-            ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
+                shape_list(lengths)[0], bs
+            ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
         # assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
         # assert (src_enc is None) == (src_len is None)
@@ -560,28 +526,28 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # assert src_enc.size(0) == bs
         # generate masks
-        mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
+        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
         # if self.is_decoder and src_enc is not None:
         # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
         # position_ids
-        if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
-            inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1))
+        if position_ids is None:
+            position_ids = tf.expand_dims(tf.range(slen), axis=0)
+            position_ids = tf.tile(position_ids, (bs, 1))
         if tf.executing_eagerly():
             # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
             tf.debugging.assert_equal(
-                shape_list(inputs["position_ids"]), [bs, slen]
-            ), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
+                shape_list(position_ids), [bs, slen]
+            ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
         # position_ids = position_ids.transpose(0, 1)
         # langs
-        if inputs["langs"] is not None and tf.executing_eagerly():
+        if langs is not None and tf.executing_eagerly():
             # assert shape_list(langs) == [bs, slen] # (slen, bs)
             tf.debugging.assert_equal(
-                shape_list(inputs["langs"]), [bs, slen]
-            ), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
+                shape_list(langs), [bs, slen]
+            ), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
         # langs = langs.transpose(0, 1)
         # Prepare head mask if needed
@@ -589,50 +555,50 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.n_layers
+            head_mask = [None] * self.n_layers
         # do not recompute cached elements
-        if inputs["cache"] is not None and inputs["input_ids"] is not None:
-            _slen = slen - inputs["cache"]["slen"]
-            inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
-            inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
-            if inputs["langs"] is not None:
-                inputs["langs"] = inputs["langs"][:, -_slen:]
+        if cache is not None and input_ids is not None:
+            _slen = slen - cache["slen"]
+            input_ids = input_ids[:, -_slen:]
+            position_ids = position_ids[:, -_slen:]
+            if langs is not None:
+                langs = langs[:, -_slen:]
             mask = mask[:, -_slen:]
             attn_mask = attn_mask[:, -_slen:]
         # embeddings
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
-        tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"])
+        tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids)
-        if inputs["langs"] is not None and self.use_lang_emb:
-            tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"])
-        if inputs["token_type_ids"] is not None:
-            tensor = tensor + self.embeddings(inputs["token_type_ids"])
+        if langs is not None and self.use_lang_emb:
+            tensor = tensor + tf.gather(self.lang_embeddings, langs)
+        if token_type_ids is not None:
+            tensor = tensor + self.embeddings(token_type_ids)
         tensor = self.layer_norm_emb(tensor)
-        tensor = self.dropout(tensor, training=inputs["training"])
+        tensor = self.dropout(tensor, training=training)
         mask = tf.cast(mask, dtype=tensor.dtype)
         tensor = tensor * tf.expand_dims(mask, axis=-1)
         # hidden_states and attentions cannot be None in graph mode.
-        hidden_states = () if inputs["output_hidden_states"] else None
-        attentions = () if inputs["output_attentions"] else None
+        hidden_states = () if output_hidden_states else None
+        attentions = () if output_attentions else None
         # transformer layers
         for i in range(self.n_layers):
             # LayerDrop
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 hidden_states = hidden_states + (tensor,)
             # self attention
@@ -641,17 +607,17 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
                     tensor,
                     attn_mask,
                     None,
-                    inputs["cache"],
-                    inputs["head_mask"][i],
-                    inputs["output_attentions"],
-                    training=inputs["training"],
+                    cache,
+                    head_mask[i],
+                    output_attentions,
+                    training=training,
                 )
                 attn = attn_outputs[0]
-                if inputs["output_attentions"]:
+                if output_attentions:
                     attentions = attentions + (attn_outputs[1],)
-                attn = self.dropout(attn, training=inputs["training"])
+                attn = self.dropout(attn, training=training)
                 tensor = tensor + attn
                 tensor = self.layer_norm1[i](tensor)
             else:
@@ -660,17 +626,17 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
                     tensor_normalized,
                     attn_mask,
                     None,
-                    inputs["cache"],
-                    inputs["head_mask"][i],
-                    inputs["output_attentions"],
-                    training=inputs["training"],
+                    cache,
+                    head_mask[i],
+                    output_attentions,
+                    training=training,
                 )
                 attn = attn_outputs[0]
-                if inputs["output_attentions"]:
+                if output_attentions:
                     attentions = attentions + (attn_outputs[1],)
-                attn = self.dropout(attn, training=inputs["training"])
+                attn = self.dropout(attn, training=training)
                 tensor = tensor + attn
             # encoder attention (for decoder only)
@@ -691,17 +657,17 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
            tensor = tensor * tf.expand_dims(mask, axis=-1)
         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             hidden_states = hidden_states + (tensor,)
         # update cache length
-        if inputs["cache"] is not None:
-            inputs["cache"]["slen"] += tensor.size(1)
+        if cache is not None:
+            cache["slen"] += tensor.size(1)
         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)
-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
         return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
@@ -819,6 +785,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
             langs = None
         return {"input_ids": inputs, "langs": langs}
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -843,9 +810,8 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
         training: Optional[bool] = False,
         **kwargs,
     ) -> Union[Tuple, TFFlaubertWithLMHeadModelOutput]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             langs=langs,
@@ -859,27 +825,11 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         output = transformer_outputs[0]
         outputs = self.pred_layer(output)
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (outputs,) + transformer_outputs[1:]
         return TFFlaubertWithLMHeadModelOutput(