Unverified Commit ddbc9ae0 authored by Louis Owen, committed by GitHub

Update XLM with TF decorator (#16247)



* update XLM with tf decorator

* move to top decorator

* set unpack_inputs as top decorator

Co-authored-by: Louis Owen <yellow@Louis-Owen.local>
parent a6271967
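
The pattern applied throughout the diff: each `call()` used to open with a manual `inputs = input_processing(...)` call and then read every argument through `inputs["..."]`; the `@unpack_inputs` decorator from `modeling_tf_utils` now performs that normalization before the body runs, so the body works with the plain keyword arguments directly. Per the commit message, the decorator is placed as the top (outermost) decorator on each `call`. A minimal sketch of the pattern on a hypothetical toy layer (only `unpack_inputs`, `XLMConfig`, and the argument handling mirror this change; `TFToyLayer` is illustrative and not part of the commit):

```python
import tensorflow as tf

from transformers import XLMConfig
from transformers.modeling_tf_utils import unpack_inputs


class TFToyLayer(tf.keras.layers.Layer):
    """Hypothetical layer, used only to illustrate the decorator pattern."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config  # the decorator reads self.config while normalizing inputs

    @unpack_inputs  # top (outermost) decorator, as in this commit
    def call(self, input_ids=None, inputs_embeds=None, training=False, **kwargs):
        # With the decorator, the body uses plain keyword arguments (`input_ids`)
        # instead of the `inputs = input_processing(...)` dict and the
        # `inputs["input_ids"]` lookups removed in this diff.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        return input_ids if input_ids is not None else inputs_embeds


layer = TFToyLayer(XLMConfig())
print(layer(input_ids=tf.constant([[0, 1, 2]])))  # keyword args are unpacked before call() runs
```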
@@ -48,8 +48,8 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -344,6 +344,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -362,49 +363,31 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         **kwargs,
     ):
         # removed: src_enc=None, src_len=None
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            langs=langs,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            lengths=lengths,
-            cache=cache,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
 
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            bs, slen = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            bs, slen = shape_list(inputs["inputs_embeds"])[:2]
+        elif input_ids is not None:
+            bs, slen = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            bs, slen = shape_list(inputs_embeds)[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        if inputs["lengths"] is None:
-            if inputs["input_ids"] is not None:
-                inputs["lengths"] = tf.reduce_sum(
-                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1
+        if lengths is None:
+            if input_ids is not None:
+                lengths = tf.reduce_sum(
+                    tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1
                 )
             else:
-                inputs["lengths"] = tf.convert_to_tensor([slen] * bs)
+                lengths = tf.convert_to_tensor([slen] * bs)
         # mask = input_ids != self.pad_index
 
         # check inputs
         # assert shape_list(lengths)[0] == bs
         if tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["lengths"])[0], bs
-            ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
+                shape_list(lengths)[0], bs
+            ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
         # assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
         # assert (src_enc is None) == (src_len is None)
@@ -413,28 +396,28 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # assert src_enc.size(0) == bs
 
         # generate masks
-        mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
+        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
 
         # position_ids
-        if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
-            inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1))
+        if position_ids is None:
+            position_ids = tf.expand_dims(tf.range(slen), axis=0)
+            position_ids = tf.tile(position_ids, (bs, 1))
 
         if tf.executing_eagerly():
             # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
-                shape_list(inputs["position_ids"]), [bs, slen]
-            ), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
+                shape_list(position_ids), [bs, slen]
+            ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
         # position_ids = position_ids.transpose(0, 1)
 
         # langs
-        if inputs["langs"] is not None and tf.executing_eagerly():
+        if langs is not None and tf.executing_eagerly():
             # assert shape_list(langs) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
-                shape_list(inputs["langs"]), [bs, slen]
-            ), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
+                shape_list(langs), [bs, slen]
+            ), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
         # langs = langs.transpose(0, 1)
 
         # Prepare head mask if needed
@@ -442,43 +425,43 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.n_layers
+            head_mask = [None] * self.n_layers
 
         # do not recompute cached elements
-        if inputs["cache"] is not None and inputs["input_ids"] is not None:
-            _slen = slen - inputs["cache"]["slen"]
-            inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
-            inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
-            if inputs["langs"] is not None:
-                inputs["langs"] = inputs["langs"][:, -_slen:]
+        if cache is not None and input_ids is not None:
+            _slen = slen - cache["slen"]
+            input_ids = input_ids[:, -_slen:]
+            position_ids = position_ids[:, -_slen:]
+            if langs is not None:
+                langs = langs[:, -_slen:]
             mask = mask[:, -_slen:]
             attn_mask = attn_mask[:, -_slen:]
 
         # embeddings
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
 
-        tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"])
+        tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids)
 
-        if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
-            tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"])
-        if inputs["token_type_ids"] is not None:
-            tensor = tensor + self.embeddings(inputs["token_type_ids"])
+        if langs is not None and self.use_lang_emb and self.n_langs > 1:
+            tensor = tensor + tf.gather(self.lang_embeddings, langs)
+        if token_type_ids is not None:
+            tensor = tensor + self.embeddings(token_type_ids)
 
         tensor = self.layer_norm_emb(tensor)
-        tensor = self.dropout(tensor, training=inputs["training"])
+        tensor = self.dropout(tensor, training=training)
         mask = tf.cast(mask, dtype=tensor.dtype)
         tensor = tensor * tf.expand_dims(mask, axis=-1)
 
         # transformer layers
-        hidden_states = () if inputs["output_hidden_states"] else None
-        attentions = () if inputs["output_attentions"] else None
+        hidden_states = () if output_hidden_states else None
+        attentions = () if output_attentions else None
         for i in range(self.n_layers):
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 hidden_states = hidden_states + (tensor,)
 
             # self attention
@@ -486,17 +469,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
                 tensor,
                 attn_mask,
                 None,
-                inputs["cache"],
-                inputs["head_mask"][i],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                cache,
+                head_mask[i],
+                output_attentions,
+                training=training,
             )
             attn = attn_outputs[0]
 
-            if inputs["output_attentions"]:
+            if output_attentions:
                 attentions = attentions + (attn_outputs[1],)
 
-            attn = self.dropout(attn, training=inputs["training"])
+            attn = self.dropout(attn, training=training)
             tensor = tensor + attn
             tensor = self.layer_norm1[i](tensor)
@@ -513,17 +496,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         tensor = tensor * tf.expand_dims(mask, axis=-1)
 
         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             hidden_states = hidden_states + (tensor,)
 
         # update cache length
-        if inputs["cache"] is not None:
-            inputs["cache"]["slen"] += tensor.size(1)
+        if cache is not None:
+            cache["slen"] += tensor.size(1)
 
         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
 
         return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
@@ -701,6 +684,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFXLMMainLayer(config, name="transformer")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -725,9 +709,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             langs=langs,
@@ -741,22 +723,6 @@ class TFXLMModel(TFXLMPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
 
         return outputs
@@ -854,6 +820,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
             langs = None
         return {"input_ids": inputs, "langs": langs}
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -878,9 +845,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             langs=langs,
@@ -894,28 +859,12 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         output = transformer_outputs[0]
 
         outputs = self.pred_layer(output)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (outputs,) + transformer_outputs[1:]
 
         return TFXLMWithLMHeadModelOutput(
@@ -944,6 +893,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
         self.transformer = TFXLMMainLayer(config, name="transformer")
         self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -975,9 +925,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             langs=langs,
@@ -990,32 +938,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         output = transformer_outputs[0]
 
         logits = self.sequence_summary(output)
 
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1070,6 +1001,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
             "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
         }
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1095,56 +1027,30 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            langs=langs,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            lengths=lengths,
-            cache=cache,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
 
-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_position_ids = (
-            tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
-        )
-        flat_langs = tf.reshape(inputs["langs"], (-1, seq_length)) if inputs["langs"] is not None else None
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
 
-        if inputs["lengths"] is not None:
+        if lengths is not None:
             logger.warning(
                 "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
                 "attention mask instead.",
             )
-            inputs["lengths"] = None
+            lengths = None
 
         transformer_outputs = self.transformer(
             flat_input_ids,
@@ -1152,23 +1058,23 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
             flat_langs,
             flat_token_type_ids,
             flat_position_ids,
-            inputs["lengths"],
-            inputs["cache"],
-            inputs["head_mask"],
+            lengths,
+            cache,
+            head_mask,
             flat_inputs_embeds,
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)
         logits = self.logits_proj(logits)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))
 
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1220,6 +1126,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLoss):
             config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
         )
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1249,10 +1156,8 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLoss):
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
-            config=self.config,
             attention_mask=attention_mask,
             langs=langs,
             token_type_ids=token_type_ids,
@@ -1264,33 +1169,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = transformer_outputs[0]
 
-        sequence_output = self.dropout(sequence_output, training=inputs["training"])
+        sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)
 
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1324,6 +1212,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
             config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
         )
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1360,9 +1249,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
            input_ids=input_ids,
             attention_mask=attention_mask,
             langs=langs,
@@ -1375,25 +1262,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            langs=inputs["langs"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            lengths=inputs["lengths"],
-            cache=inputs["cache"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = transformer_outputs[0]
@@ -1403,12 +1272,12 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
         end_logits = tf.squeeze(end_logits, axis=-1)
 
         loss = None
-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
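
The calling convention of the TF XLM classes is unchanged by this refactor; a quick smoke test (not part of the commit; the `xlm-mlm-en-2048` checkpoint and the sample sentence are only illustrative) would be:

```python
from transformers import TFXLMModel, XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = TFXLMModel.from_pretrained("xlm-mlm-en-2048")

# Keyword arguments are routed through @unpack_inputs before TFXLMMainLayer.call runs.
encoded = tokenizer("Hello, world!", return_tensors="tf")
outputs = model(**encoded)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```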