"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "da7ea9a4e337eb2eed204090fe38198418c01134"
Unverified Commit 015de6f0 authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

TF clearer model variable naming: xlnet (#16150)

parent a23a7c0c
...@@ -42,8 +42,8 @@ from ...modeling_tf_utils import ( ...@@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
) )
from ...tf_utils import shape_list from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
...@@ -578,6 +578,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -578,6 +578,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return pos_emb return pos_emb
@unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -596,63 +597,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -596,63 +597,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
training=False, training=False,
**kwargs, **kwargs,
): ):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if training and inputs["use_mems"] is None: if training and use_mems is None:
inputs["use_mems"] = self.use_mems_train use_mems = self.use_mems_train
else: else:
inputs["use_mems"] = self.use_mems_eval use_mems = self.use_mems_eval
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif input_ids is not None:
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0)) input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(inputs["input_ids"])[:2] qlen, bsz = shape_list(input_ids)[:2]
elif inputs["inputs_embeds"] is not None: elif inputs_embeds is not None:
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2)) inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2] qlen, bsz = shape_list(inputs_embeds)[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
inputs["token_type_ids"] = ( token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
) attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
inputs["input_mask"] = ( perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
)
inputs["attention_mask"] = (
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
)
inputs["perm_mask"] = (
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
)
inputs["target_mapping"] = (
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
)
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0 mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen klen = mlen + qlen
# Attention mask # Attention mask
...@@ -666,19 +638,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -666,19 +638,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError(f"Unsupported attention type: {self.attn_type}") raise ValueError(f"Unsupported attention type: {self.attn_type}")
# data mask: input mask & perm mask # data mask: input mask & perm mask
assert inputs["input_mask"] is None or inputs["attention_mask"] is None, ( assert input_mask is None or attention_mask is None, (
"You can only use one of input_mask (uses 1 for padding) " "You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one." "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
) )
if inputs["input_mask"] is None and inputs["attention_mask"] is not None: if input_mask is None and attention_mask is not None:
one_cst = tf.constant(1.0) one_cst = tf.constant(1.0)
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype) input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None: if input_mask is not None and perm_mask is not None:
data_mask = inputs["input_mask"][None] + inputs["perm_mask"] data_mask = input_mask[None] + perm_mask
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None: elif input_mask is not None and perm_mask is None:
data_mask = inputs["input_mask"][None] data_mask = input_mask[None]
elif inputs["input_mask"] is None and inputs["perm_mask"] is not None: elif input_mask is None and perm_mask is not None:
data_mask = inputs["perm_mask"] data_mask = perm_mask
else: else:
data_mask = None data_mask = None
...@@ -704,33 +676,33 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -704,33 +676,33 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
non_tgt_mask = None non_tgt_mask = None
# Word embeddings and prepare h & g hidden states # Word embeddings and prepare h & g hidden states
if inputs["inputs_embeds"] is not None: if inputs_embeds is not None:
word_emb_k = inputs["inputs_embeds"] word_emb_k = inputs_embeds
else: else:
word_emb_k = self.word_embedding(inputs["input_ids"]) word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k, training=inputs["training"]) output_h = self.dropout(word_emb_k, training=training)
if inputs["target_mapping"] is not None: if target_mapping is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1]) word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
# else: # We removed the inp_q input which was same as target mapping # else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None] # inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k # word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q, training=inputs["training"]) output_g = self.dropout(word_emb_q, training=training)
else: else:
output_g = None output_g = None
# Segment embedding # Segment embedding
if inputs["token_type_ids"] is not None: if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat` # Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0: if mlen > 0:
mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype) mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0) cat_ids = tf.concat([mem_pad, token_type_ids], 0)
else: else:
cat_ids = inputs["token_type_ids"] cat_ids = token_type_ids
# `1` indicates not in the same segment [qlen x klen x bsz] # `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast( seg_mat = tf.cast(
tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
dtype=inputs["token_type_ids"].dtype, dtype=token_type_ids.dtype,
) )
seg_mat = tf.one_hot(seg_mat, 2) seg_mat = tf.one_hot(seg_mat, 2)
else: else:
...@@ -738,29 +710,29 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -738,29 +710,29 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Positional encoding # Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz) pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb, training=inputs["training"]) pos_emb = self.dropout(pos_emb, training=training)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if inputs["head_mask"] is not None: if head_mask is not None:
raise NotImplementedError raise NotImplementedError
else: else:
inputs["head_mask"] = [None] * self.n_layer head_mask = [None] * self.n_layer
new_mems = () new_mems = ()
if inputs["mems"] is None: if mems is None:
inputs["mems"] = [None] * len(self.layer) mems = [None] * len(self.layer)
attentions = [] if inputs["output_attentions"] else None attentions = [] if output_attentions else None
hidden_states = [] if inputs["output_hidden_states"] else None hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if inputs["use_mems"]: if use_mems:
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),) new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if inputs["output_hidden_states"]: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module( outputs = layer_module(
...@@ -770,34 +742,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -770,34 +742,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask, attn_mask,
pos_emb, pos_emb,
seg_mat, seg_mat,
inputs["mems"][i], mems[i],
inputs["target_mapping"], target_mapping,
inputs["head_mask"][i], head_mask[i],
inputs["output_attentions"], output_attentions,
training=inputs["training"], training=training,
) )
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if inputs["output_attentions"]: if output_attentions:
attentions.append(outputs[2]) attentions.append(outputs[2])
# Add last hidden state # Add last hidden state
if inputs["output_hidden_states"]: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"]) output = self.dropout(output_g if output_g is not None else output_h, training=training)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2)) output = tf.transpose(output, perm=(1, 0, 2))
if not inputs["use_mems"]: if not use_mems:
new_mems = None new_mems = None
if inputs["output_hidden_states"]: if output_hidden_states:
if output_g is not None: if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
else: else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states) hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
if inputs["output_attentions"]: if output_attentions:
if inputs["target_mapping"] is not None: if target_mapping is not None:
# when target_mapping is provided, there are 2-tuple of attentions # when target_mapping is provided, there are 2-tuple of attentions
attentions = tuple( attentions = tuple(
tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
...@@ -805,7 +777,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -805,7 +777,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else: else:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
if not inputs["return_dict"]: if not return_dict:
return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None) return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
return TFXLNetModelOutput( return TFXLNetModelOutput(
...@@ -1154,6 +1126,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -1154,6 +1126,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name="transformer") self.transformer = TFXLNetMainLayer(config, name="transformer")
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1179,9 +1152,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -1179,9 +1152,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
training=False, training=False,
**kwargs, **kwargs,
): ):
inputs = input_processing( outputs = self.transformer(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1196,23 +1167,6 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -1196,23 +1167,6 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
return outputs return outputs
...@@ -1286,6 +1240,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1286,6 +1240,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
return inputs return inputs
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -1349,9 +1304,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1349,9 +1304,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
... 0 ... 0
>>> ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size] >>> ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
```""" ```"""
inputs = input_processing( transformer_outputs = self.transformer(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1365,34 +1318,16 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1365,34 +1318,16 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
hidden_state = transformer_outputs[0] hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state, training=inputs["training"]) logits = self.lm_loss(hidden_state, training=training)
loss = None loss = None
if inputs["labels"] is not None: if labels is not None:
loss = self.hf_compute_loss(inputs["labels"], logits) loss = self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1432,6 +1367,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1432,6 +1367,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1464,9 +1400,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1464,9 +1400,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
inputs = input_processing( transformer_outputs = self.transformer(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1480,34 +1414,16 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1480,34 +1414,16 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
output = self.sequence_summary(output) output = self.sequence_summary(output)
logits = self.logits_proj(output) logits = self.logits_proj(output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1558,6 +1474,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1558,6 +1474,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
""" """
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1590,72 +1507,45 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1590,72 +1507,45 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
""" """
inputs = input_processing( if input_ids is not None:
func=self.call, num_choices = shape_list(input_ids)[1]
config=self.config, seq_length = shape_list(input_ids)[2]
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs["inputs_embeds"])[1] num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs["inputs_embeds"])[2] seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = ( flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
) flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_input_mask = (
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3])) tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs["inputs_embeds"] is not None if inputs_embeds is not None
else None else None
) )
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
inputs["mems"], mems,
inputs["perm_mask"], perm_mask,
inputs["target_mapping"], target_mapping,
flat_token_type_ids, flat_token_type_ids,
flat_input_mask, flat_input_mask,
inputs["head_mask"], head_mask,
flat_inputs_embeds, flat_inputs_embeds,
inputs["use_mems"], use_mems,
inputs["output_attentions"], output_attentions,
inputs["output_hidden_states"], output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits) loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
if not inputs["return_dict"]: if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1706,6 +1596,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1706,6 +1596,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1737,9 +1628,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1737,9 +1628,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
""" """
inputs = input_processing( transformer_outputs = self.transformer(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1753,31 +1642,13 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1753,31 +1642,13 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.classifier(output) logits = self.classifier(output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1812,6 +1683,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1812,6 +1683,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1849,9 +1721,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1849,9 +1721,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss. are not taken into account for computing the loss.
""" """
inputs = input_processing( transformer_outputs = self.transformer(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1865,26 +1735,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1865,26 +1735,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
...@@ -1894,12 +1745,12 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1894,12 +1745,12 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None: if start_positions is not None and end_positions is not None:
labels = {"start_position": inputs["start_positions"]} labels = {"start_position": start_positions}
labels["end_position"] = inputs["end_positions"] labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels, (start_logits, end_logits)) loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]: if not return_dict:
output = (start_logits, end_logits) + transformer_outputs[1:] output = (start_logits, end_logits) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment