Unverified Commit 1c1e377e authored by Johannes Kolbe, committed by GitHub

TF - add unpack_inputs decorator for marian (#16226)



* add unpack_inputs decorator

* small fix for attn_mask string
Co-authored-by: Johannes Kolbe <johannes.kolbe@tech.better.team>
parent 81643edd
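
For context, a short sketch of the pattern this commit applies: each `call` method stops routing its arguments through `input_processing` (and reading them back from an `inputs` dict) and is instead wrapped with the `@unpack_inputs` decorator imported from `modeling_tf_utils`, so the body uses its named arguments directly. The decorator below is a simplified, hypothetical stand-in written only to illustrate the shape of the change, not the actual implementation in transformers, and `TFMarianEncoderSketch` is an invented class.

import functools

def unpack_inputs(call):
    # Simplified, hypothetical stand-in for transformers.modeling_tf_utils.unpack_inputs.
    # The real decorator also resolves config defaults, handles Keras symbolic and
    # tuple/dict inputs, and validates unexpected kwargs; this only shows the idea.
    @functools.wraps(call)
    def wrapper(self, *args, **kwargs):
        # If the caller passed a single dict (the old `inputs`-style payload),
        # unpack it into keyword arguments before invoking `call`.
        if len(args) == 1 and isinstance(args[0], dict):
            kwargs = {**args[0], **kwargs}
            args = ()
        return call(self, *args, **kwargs)
    return wrapper

class TFMarianEncoderSketch:
    # Invented class, only to show the before/after shape of `call`.
    @unpack_inputs
    def call(self, input_ids=None, attention_mask=None, training=False):
        # Before this PR the body started with
        #     inputs = input_processing(func=self.call, config=self.config, ...)
        # and every access went through inputs["..."]; now the named arguments
        # are used directly.
        if input_ids is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        return input_ids, attention_mask, training

The hunks below apply exactly this substitution to TFMarianEncoder, TFMarianDecoder, TFMarianMainLayer, TFMarianModel, and TFMarianMTModel, and also rework the decoder's head-mask check to iterate over (name, tensor) pairs now that the `inputs` dict is gone (the "attn_mask string" fix mentioned in the commit message).
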
@@ -43,8 +43,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -690,6 +690,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -744,79 +745,66 @@ class TFMarianEncoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.dropout(hidden_states, training=training)
         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
+            attention_mask = _expand_mask(attention_mask)
         else:
             attention_mask = None
-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
         # check if head_mask has a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )
         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue
             hidden_states, attn = encoder_layer(
                 hidden_states,
                 attention_mask,
-                inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                head_mask[idx] if head_mask is not None else None,
             )
-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions += (attn,)
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)
-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
@@ -856,6 +844,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -939,45 +928,25 @@ class TFMarianDecoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
-        hidden_states = inputs["inputs_embeds"]
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+        hidden_states = inputs_embeds
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -987,70 +956,66 @@ class TFMarianDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )
-        if inputs["attention_mask"] is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
-        hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
+        hidden_states = self.dropout(hidden_states + positions, training=training)
         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue
-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                 past_key_value=past_key_value,
             )
-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)
-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)
-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
@@ -1097,6 +1062,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1117,77 +1083,54 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"]
-            if inputs["output_hidden_states"] is not None
-            else self.config.output_hidden_states
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
            )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )
         # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()
         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1195,9 +1138,9 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )
@@ -1217,6 +1160,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
     def get_decoder(self):
         return self.model.decoder
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1244,16 +1188,14 @@ class TFMarianModel(TFMarianPreTrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
+            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
@@ -1263,26 +1205,6 @@ class TFMarianModel(TFMarianPreTrainedModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         return outputs
@@ -1345,6 +1267,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def set_bias(self, value):
         self.final_logits_bias = value["final_logits_bias"]
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
@@ -1378,17 +1301,28 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
         Returns:
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.fill(shape_list(labels), tf.cast(-100, labels.dtype)),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1396,46 +1330,13 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["labels"] is not None:
-            inputs["labels"] = tf.where(
-                inputs["labels"] == self.config.pad_token_id,
-                tf.fill(shape_list(inputs["labels"]), tf.cast(-100, inputs["labels"].dtype)),
-                inputs["labels"],
-            )
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
-        if not inputs["return_dict"]:
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
...