"vscode:/vscode.git/clone" did not exist on "99f6e42113a374ff8b9b2fe535113f4f5bd12283"
Unverified commit 2a6fbe6a, authored by Patrick von Platen, committed by GitHub

[XLNet] Fix mems behavior (#8567)

* fix mems in xlnet

* fix use_mems

* fix use_mem_len

* fix use mems

* clean docs

* fix tf typo

* make xlnet tf for generation work

* fix tf test

* refactor use cache

* add use cache for missing models

* correct use_cache in generate

* correct use cache in tf generate

* fix tf

* correct getattr typo

* make sylvain happy

* change in docs as well

* do not apply to cookie cutter statements

* fix tf test

* make pytorch model fully backward compatible
parent 369f1d77
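In short, this commit replaces XLNet's overloaded `use_cache` flag with explicit memory controls: `use_mems_eval` and `use_mems_train` on the config, plus a `use_mems` argument at call time, with a deprecation path for the old name. A minimal sketch of the resulting API, assuming the `transformers` version at this commit:

```python
from transformers import XLNetConfig, XLNetLMHeadModel

# Before this commit, a single `use_cache` flag controlled the recurrent
# memory and training silently forced it on. Now there are two config flags:
config = XLNetConfig(
    use_mems_eval=True,   # use the recurrent memory at inference (default True)
    use_mems_train=False, # skip it during fine-tuning (default False)
)

# The old kwarg still works but emits a FutureWarning and is mapped onto
# `use_mems_eval`:
legacy_config = XLNetConfig(use_cache=True)  # -> use_mems_eval=True

model = XLNetLMHeadModel(config)
model.eval()
# At call time the argument is now `use_mems`; left as None, it falls back to
# the config flag matching the current mode (train vs. eval):
# outputs = model(input_ids, use_mems=True)
```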
src/transformers/models/t5/configuration_t5.py

@@ -69,6 +69,8 @@ class T5Config(PretrainedConfig):
        feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
            Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses
            the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
    """
    model_type = "t5"
    keys_to_ignore_at_inference = ["past_key_values"]
@@ -88,6 +90,7 @@ class T5Config(PretrainedConfig):
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
+        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        **kwargs
@@ -112,6 +115,7 @@ class T5Config(PretrainedConfig):
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
+        self.use_cache = use_cache

    @property
    def hidden_size(self):
...
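The T5 change is smaller: it only surfaces `use_cache` as a real config attribute, so downstream code can read `config.use_cache` instead of assuming a hard-coded default. A quick sketch, assuming this `transformers` version:

```python
from transformers import T5Config

# `use_cache` is now a first-class config field: when True, the decoder
# returns past key/values for fast autoregressive decoding.
config = T5Config(use_cache=True)
assert config.use_cache

# Turning it off trades decoding speed for lower memory use.
assert not T5Config(use_cache=False).use_cache
```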
src/transformers/models/t5/modeling_tf_t5.py

@@ -884,7 +884,7 @@ T5_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
            details.

-            To know more on how to prepare :obj:`inputs` for pre-training take a look at `T5 Training
+            To know more on how to prepare :obj:`inputs` for pretraining take a look at `T5 Training
            <./t5.html#training>`__.
        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Provide for sequence to sequence training. T5 uses the :obj:`pad_token_id` as the starting token for
...
src/transformers/models/xlnet/configuration_xlnet.py

@@ -15,6 +15,8 @@
 # limitations under the License.
 """ XLNet configuration """

+import warnings
+
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging

@@ -106,12 +108,18 @@ class XLNetConfig(PretrainedConfig):
            Used in the SQuAD evaluation script.
        end_n_top (:obj:`int`, `optional`, defaults to 5):
            Used in the SQuAD evaluation script.
-        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not the model should return the last pre-computed hidden states.
+        use_mems_eval (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should make use of the recurrent memory mechanism in evaluation mode.
+        use_mems_train (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not the model should make use of the recurrent memory mechanism in train mode.

            .. note::
-                This flag behaves differently from with other models: it just controls the inference behavior, during
-                training the model always uses ``use_cache=True``.
+                For pretraining, it is recommended to set ``use_mems_train`` to :obj:`True`. For fine-tuning, it is
+                recommended to set ``use_mems_train`` to :obj:`False` as discussed `here
+                <https://github.com/zihangdai/xlnet/issues/41#issuecomment-505102587>`__. If ``use_mems_train`` is set
+                to :obj:`True`, one has to make sure that the train batches are correctly pre-processed, `e.g.`
+                :obj:`batch_1 = [[This line is], [This is the]]` and :obj:`batch_2 = [[ the first line], [ second
+                line]]` and that all batches are of equal size.

    Examples::

@@ -145,6 +153,8 @@ class XLNetConfig(PretrainedConfig):
        dropout=0.1,
        mem_len=512,
        reuse_len=None,
+        use_mems_eval=True,
+        use_mems_train=False,
        bi_data=False,
        clamp_len=-1,
        same_length=False,
@@ -197,6 +207,16 @@ class XLNetConfig(PretrainedConfig):
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

+        if "use_cache" in kwargs:
+            warnings.warn(
+                "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems_eval` instead.",
+                FutureWarning,
+            )
+            use_mems_eval = kwargs["use_cache"]
+
+        self.use_mems_eval = use_mems_eval
+        self.use_mems_train = use_mems_train
+
    @property
    def max_position_embeddings(self):
        return -1
...
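The note above is easy to get wrong in practice: with `use_mems_train=True`, row `b` of batch `t+1` must directly continue row `b` of batch `t`, and all batches must share one shape. A hedged sketch of that pre-processing; `make_contiguous_batches` and its arguments are illustrative names, not library API:

```python
import torch

def make_contiguous_batches(tokens, batch_size, seq_len):
    # Trim the stream so it divides evenly into `batch_size` parallel streams.
    n_batches = tokens.size(0) // (batch_size * seq_len)
    tokens = tokens[: n_batches * batch_size * seq_len]
    # Row b holds one contiguous stream; successive column chunks become
    # successive batches, so batch t+1 continues batch t row by row.
    streams = tokens.view(batch_size, -1)
    return [streams[:, i * seq_len : (i + 1) * seq_len] for i in range(n_batches)]

batches = make_contiguous_batches(torch.arange(1000), batch_size=2, seq_len=8)
# Row 0 of batch 1 picks up exactly where row 0 of batch 0 stopped.
assert batches[1][0][0] == batches[0][0][-1] + 1
# All batches have the same shape, as the config docstring requires.
assert all(b.shape == (2, 8) for b in batches)
```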
src/transformers/models/xlnet/modeling_tf_xlnet.py

@@ -440,6 +440,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        self.layer = [TFXLNetLayer(config, name="layer_._{}".format(i)) for i in range(config.n_layer)]
        self.dropout = tf.keras.layers.Dropout(config.dropout)

+        self.use_mems_eval = config.use_mems_eval
+        self.use_mems_train = config.use_mems_train
+
    def get_input_embeddings(self):
        return self.word_embedding

@@ -489,14 +492,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        return ret

    def cache_mem(self, curr_out, prev_mem):
-        """cache hidden states into memory."""
+        # cache hidden states into memory.
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

+        if self.mem_len is None or self.mem_len == 0:
+            # If :obj:`use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
+            # and returns all of the past and current hidden states.
+            cutoff = 0
+        else:
+            # If :obj:`use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
+            # states. This is the preferred setting for training and long-form generation.
+            cutoff = -self.mem_len
+
        if prev_mem is None:
-            new_mem = curr_out[-self.mem_len :]
+            # if :obj:`use_mems` is active and `mem_len` is defined, the model
+            new_mem = curr_out[cutoff:]
        else:
-            new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len :]
+            new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff:]

        return tf.stop_gradient(new_mem)
@@ -569,7 +581,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -587,7 +599,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -602,6 +614,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        )
        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

+        if training:
+            use_mems = use_mems if use_mems is not None else self.use_mems_train
+        else:
+            use_mems = use_mems if use_mems is not None else self.use_mems_eval
+
        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
@@ -737,7 +754,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            # cache new mems
-            if self.mem_len is not None and self.mem_len > 0 and use_cache:
+            if use_mems:
                new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@@ -768,7 +785,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = tf.transpose(output, perm=(1, 0, 2))

-        if not (self.mem_len is not None and self.mem_len > 0 and use_cache):
+        if not use_mems:
            new_mems = None

        if output_hidden_states:
            if output_g is not None:
@@ -1066,7 +1083,7 @@ XLNET_INPUTS_DOCSTRING = r"""
            decoding. The token ids which have their past given to this model should not be passed as :obj:`input_ids`
            as they have already been computed.

-            :obj::obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
+            :obj:`use_mems` has to be set to :obj:`True` to make use of :obj:`mems`.
        perm_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`):
            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
@@ -1147,7 +1164,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -1165,7 +1182,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1182,7 +1199,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
            input_mask=inputs["input_mask"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
+            use_mems=inputs["use_mems"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
@@ -1207,7 +1224,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
    def get_output_embeddings(self):
        return self.lm_loss.input_embeddings

-    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
        # Add dummy token at the end (no attention on this one)

        # At every pass, the attention values for the new token and the two last generated tokens
@@ -1238,7 +1255,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
            "input_ids": inputs,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
-            "use_cache": kwargs["use_cache"],
+            "use_mems": kwargs.get("use_mems"),
        }

        # if past is defined in model kwargs then use it for faster decoding
@@ -1260,7 +1277,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -1309,7 +1326,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1328,7 +1345,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
            input_mask=inputs["input_mask"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
+            use_mems=inputs["use_mems"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=return_dict,
@@ -1395,7 +1412,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -1420,7 +1437,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1439,7 +1456,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
            input_mask=inputs["input_mask"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
+            use_mems=inputs["use_mems"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=return_dict,
@@ -1512,7 +1529,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
        target_mapping=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -1526,6 +1543,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
            num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)
        """
+
        inputs = input_processing(
            func=self.call,
            input_ids=input_ids,
@@ -1537,7 +1555,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1579,7 +1597,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
            flat_input_mask,
            inputs["head_mask"],
            flat_inputs_embeds,
-            inputs["use_cache"],
+            inputs["use_mems"],
            inputs["output_attentions"],
            inputs["output_hidden_states"],
            return_dict=return_dict,
@@ -1639,7 +1657,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -1663,7 +1681,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1682,7 +1700,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
            input_mask=inputs["input_mask"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
+            use_mems=inputs["use_mems"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=return_dict,
@@ -1739,7 +1757,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=True,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@@ -1769,7 +1787,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1789,7 +1807,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
            input_mask=inputs["input_mask"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
+            use_mems=inputs["use_mems"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=return_dict,
...
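The `cutoff` logic added to `cache_mem` above (and mirrored in the PyTorch model below) is the core behavioral fix: without a `mem_len` the memory now grows like GPT-2's full past, while a set `mem_len` keeps a fixed sliding window. A framework-agnostic sketch, with plain lists standing in for the hidden-state tensors:

```python
def cache_mem(curr_out, prev_mem, mem_len=None, reuse_len=None):
    # Same branching as the diff: cutoff == 0 keeps the whole history,
    # cutoff == -mem_len keeps only the last `mem_len` steps.
    if reuse_len:
        curr_out = curr_out[:reuse_len]
    cutoff = 0 if not mem_len else -mem_len
    if prev_mem is None:
        new_mem = curr_out[cutoff:]
    else:
        new_mem = (prev_mem + curr_out)[cutoff:]
    return new_mem

# mem_len=None: the memory grows without bound (GPT-2-style full past).
m = None
for step in (["h0", "h1"], ["h2", "h3"]):
    m = cache_mem(step, m, mem_len=None)
assert m == ["h0", "h1", "h2", "h3"]

# mem_len=3: a fixed window holding only the last 3 hidden states.
m = None
for step in (["h0", "h1"], ["h2", "h3"]):
    m = cache_mem(step, m, mem_len=3)
assert m == ["h1", "h2", "h3"]
```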
src/transformers/models/xlnet/modeling_xlnet.py

@@ -16,6 +16,7 @@
 """
  PyTorch XLNet model.
 """
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple

@@ -876,7 +877,7 @@ XLNET_INPUTS_DOCSTRING = r"""
            decoding. The token ids which have their past given to this model should not be passed as :obj:`input_ids`
            as they have already been computed.

-            :obj::obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
+            :obj:`use_mems` has to be set to :obj:`True` to make use of :obj:`mems`.
        perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`):
            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:

@@ -997,15 +998,15 @@ class XLNetModel(XLNetPreTrainedModel):
            curr_out = curr_out[: self.reuse_len]

        if self.mem_len is None or self.mem_len == 0:
-            # If :obj:`use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
+            # If :obj:`use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
            # and returns all of the past and current hidden states.
            cutoff = 0
        else:
-            # If :obj:`use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
+            # If :obj:`use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
            # states. This is the preferred setting for training and long-form generation.
            cutoff = -self.mem_len
        if prev_mem is None:
-            # if :obj:`use_cache` is active and `mem_len` is defined, the model
+            # if :obj:`use_mems` is active and `mem_len` is defined, the model
            new_mem = curr_out[cutoff:]
        else:
            new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
@@ -1080,10 +1081,11 @@ class XLNetModel(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete after deprecation warning is removed
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

@@ -1091,7 +1093,18 @@ class XLNetModel(XLNetPreTrainedModel):
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
+
+        if "use_cache" in kwargs:
+            warnings.warn(
+                "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems` instead.",
+                FutureWarning,
+            )
+            use_mems = kwargs["use_cache"]
+
+        if self.training:
+            use_mems = use_mems if use_mems is not None else self.config.use_mems_train
+        else:
+            use_mems = use_mems if use_mems is not None else self.config.use_mems_eval

        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
@@ -1222,7 +1235,7 @@ class XLNetModel(XLNetPreTrainedModel):
        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
-            if use_cache:
+            if use_mems:
                # cache new mems
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
@@ -1253,7 +1266,7 @@ class XLNetModel(XLNetPreTrainedModel):
        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = output.permute(1, 0, 2).contiguous()

-        if not use_cache:
+        if not use_mems:
            new_mems = None

        if output_hidden_states:
@@ -1299,7 +1312,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_loss

-    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
        # Add dummy token at the end (no attention on this one)

        effective_batch_size = input_ids.shape[0]
@@ -1332,7 +1345,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            "input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
-            "use_cache": use_cache,
+            "use_mems": use_mems,
        }

        # if past is defined in model kwargs then use it for faster decoding
@@ -1355,10 +1368,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        head_mask=None,
        inputs_embeds=None,
        labels=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`):
@@ -1407,7 +1421,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            >>> next_token_logits = outputs.logits  # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
@@ -1419,10 +1432,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            **kwargs,
        )

        logits = self.lm_loss(transformer_outputs[0])
@@ -1483,10 +1497,11 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
        head_mask=None,
        inputs_embeds=None,
        labels=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1495,7 +1510,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
@@ -1507,10 +1521,11 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            **kwargs,
        )
        output = transformer_outputs[0]
@@ -1576,10 +1591,11 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
        head_mask=None,
        inputs_embeds=None,
        labels=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1588,7 +1604,6 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        outputs = self.transformer(
            input_ids,
@@ -1600,7 +1615,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -1673,10 +1688,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        head_mask=None,
        inputs_embeds=None,
        labels=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1685,7 +1701,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
            :obj:`input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
+
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1708,10 +1724,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
            target_mapping=target_mapping,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            **kwargs,
        )

        output = transformer_outputs[0]
@@ -1775,10 +1792,11 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1791,7 +1809,6 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        outputs = self.transformer(
            input_ids,
@@ -1803,10 +1820,11 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            **kwargs,
        )

        sequence_output = outputs[0]
@@ -1885,10 +1903,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
        is_impossible=None,
        cls_index=None,
        p_mask=None,
-        use_cache=None,
+        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
+        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1926,7 +1945,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
            >>> loss = outputs.loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
@@ -1938,10 +1956,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            **kwargs,
        )

        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
...
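Putting the PyTorch changes together: in eval mode the model returns `mems` by default (`use_mems_eval=True`), the renamed `use_mems` kwarg controls them per call, and the old `use_cache` kwarg is routed through a `FutureWarning`. A usage sketch, assuming the public `xlnet-base-cased` checkpoint:

```python
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
model.eval()

inputs = tokenizer("The first segment", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_mems=True, return_dict=True)
assert out.mems is not None  # one memory tensor per layer

# Feed the cached memories back in; the next segment attends to them.
next_inputs = tokenizer("continues here", return_tensors="pt")
with torch.no_grad():
    out2 = model(**next_inputs, mems=out.mems, use_mems=True, return_dict=True)

# The deprecated kwarg still works but warns and maps to `use_mems`:
# model(**inputs, use_cache=True)  # FutureWarning
```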
tests/test_modeling_tf_xlnet.py

@@ -153,7 +153,7 @@ class TFXLNetModelTester:
        inputs = [input_ids_1, input_mask]
        result = model(inputs)

-        config.mem_len = 0
+        config.use_mems_eval = False
        model = TFXLNetModel(config)
        no_mems_outputs = model(inputs)
        self.parent.assertEqual(len(no_mems_outputs), 1)
...
tests/test_modeling_xlnet.py

@@ -206,7 +206,36 @@ class XLNetModelTester:
            [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
        )

-    def create_and_check_xlnet_model_use_cache(
+    def create_and_check_use_mems_train(
+        self,
+        config,
+        input_ids_1,
+        input_ids_2,
+        input_ids_q,
+        perm_mask,
+        input_mask,
+        target_mapping,
+        segment_ids,
+        lm_labels,
+        sequence_labels,
+        is_impossible_labels,
+        token_labels,
+    ):
+        model = XLNetForSequenceClassification(config)
+        model.to(torch_device)
+        model.train()
+
+        train_size = input_ids_1.shape[0]
+
+        batch_size = 4
+        for i in range(train_size // batch_size + 1):
+            input_ids = input_ids_1[i * batch_size : (i + 1) * batch_size]
+            labels = sequence_labels[i * batch_size : (i + 1) * batch_size]
+            outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
+
+        self.parent.assertIsNone(outputs.mems)
+        self.parent.assertIsNotNone(outputs.loss)
+
+    def create_and_check_xlnet_model_use_mems(
        self,
        config,
        input_ids_1,
@@ -234,8 +263,8 @@ class XLNetModelTester:
            device=torch_device,
        )
        causal_mask = torch.triu(causal_mask, diagonal=0)
-        outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
-        outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
+        outputs_cache = model(input_ids_1, use_mems=True, perm_mask=causal_mask)
+        outputs_no_cache = model(input_ids_1, use_mems=False, perm_mask=causal_mask)
        outputs_conf = model(input_ids_1)

        self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
@@ -525,11 +554,15 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)

-    def test_xlnet_base_model_use_cache(self):
-        # checking that in auto-regressive mode, :obj:`use_cache` gives the same results
+    def test_xlnet_base_model_use_mems(self):
+        # checking that in auto-regressive mode, :obj:`use_mems` gives the same results
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
+        self.model_tester.create_and_check_xlnet_model_use_mems(*config_and_inputs)
+
+    def test_seq_classification_use_mems_train(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_use_mems_train(*config_and_inputs)

    def test_xlnet_base_model_with_att_output(self):
        self.model_tester.set_seed()
...