Unverified commit 8cc925a2 authored by Utku Saglam, committed by GitHub

TF clearer model variable naming: blenderbot (#16192)


Co-authored-by: utku saglam <utkusaglam@utku-MacBook-Pro.local>
parent 0f35cda4
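
The pattern applied throughout this diff: the `input_processing(...)` boilerplate at the top of every `call` method, together with the `inputs["..."]` dictionary lookups that follow it, is replaced by the `@unpack_inputs` decorator, after which the arguments are used as plain local variables. A minimal before/after sketch of that shape (illustrative only: the `ToyEncoder*` classes, `self.encode`, and `self.config` are placeholders rather than code from this file, and the imports assume the transformers version this commit targets, where both helpers live in `modeling_tf_utils`):

import tensorflow as tf
from transformers.modeling_tf_utils import input_processing, unpack_inputs


class ToyEncoderBefore(tf.keras.layers.Layer):
    # Old style: every argument is funneled through input_processing and then
    # read back out of the returned dict.
    def call(self, input_ids=None, attention_mask=None, training=False, **kwargs):
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            training=training,
            kwargs_call=kwargs,
        )
        return self.encode(inputs["input_ids"], inputs["attention_mask"])


class ToyEncoderAfter(tf.keras.layers.Layer):
    # New style: the decorator performs the same argument handling once, and the
    # body works with ordinary local variables.
    @unpack_inputs
    def call(self, input_ids=None, attention_mask=None, training=False, **kwargs):
        return self.encode(input_ids, attention_mask)
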
@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -645,6 +645,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -699,81 +700,67 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.dropout(hidden_states, training=training)

         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
+            attention_mask = _expand_mask(attention_mask)
         else:
             attention_mask = None

-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

             hidden_states, attn = encoder_layer(
                 hidden_states,
                 attention_mask,
-                inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                head_mask[idx] if head_mask is not None else None,
             )

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions += (attn,)

         hidden_states = self.layer_norm(hidden_states)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
@@ -814,6 +801,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -897,45 +885,24 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -945,73 +912,68 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

         hidden_states = hidden_states + positions
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None

         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)

-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue

-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                 past_key_value=past_key_value,
             )

-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)

-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)

         hidden_states = self.layer_norm(hidden_states)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
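
The decoder hunk above contains the one change that is more than a mechanical rename: with no `inputs` dict left to index by key name, the head-mask sanity check now iterates over (name, tensor) pairs directly. A self-contained sketch of that idiom, with toy values chosen purely for illustration:

import tensorflow as tf

num_layers = 2
head_mask = tf.ones((2, 16))     # illustrative per-layer head mask
cross_attn_head_mask = None      # not provided in this toy example

# Same idiom as the refactored decoder: pair each mask with its name instead of
# looking both up in an inputs dict by key.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
    if attn_mask is not None and tf.executing_eagerly():
        tf.debugging.assert_equal(
            attn_mask.shape[0],
            num_layers,
            message=f"The {attn_mask_name} should be specified for {num_layers} layers, "
            f"but it is for {attn_mask.shape[0]}.",
        )
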
@@ -1058,6 +1020,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1078,74 +1041,50 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"]
-            if inputs["output_hidden_states"] is not None
-            else self.config.output_hidden_states
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )
         # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()

         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs

         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1153,9 +1092,9 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )
@@ -1188,6 +1127,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1215,9 +1155,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -1234,26 +1172,6 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -1329,6 +1247,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BLENDERBOT_GENERATION_EXAMPLE)
@@ -1362,17 +1281,27 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
         Returns:

         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.fill(shape_list(labels), -100),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1380,46 +1309,13 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["labels"] is not None:
-            inputs["labels"] = tf.where(
-                inputs["labels"] == self.config.pad_token_id,
-                tf.fill(shape_list(inputs["labels"]), -100),
-                inputs["labels"],
-            )
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
...
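
None of the above changes the public API; calling the refactored classes looks the same as before. For reference, standard usage of TFBlenderbotForConditionalGeneration (ordinary Transformers usage rather than code from this commit; the checkpoint name is just a commonly used Blenderbot checkpoint):

from transformers import BlenderbotTokenizer, TFBlenderbotForConditionalGeneration

tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
model = TFBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")

# Keyword arguments are passed as usual; @unpack_inputs handles them inside call().
inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="tf")
reply_ids = model.generate(**inputs)
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))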