"vscode:/vscode.git/clone" did not exist on "6a9e8f6aa02a8fdc395e999c521405c4a2768b45"
Unverified Commit 2a6fbe6a authored by Patrick von Platen, committed by GitHub

[XLNet] Fix mems behavior (#8567)

* fix mems in xlnet

* fix use_mems

* fix use_mem_len

* fix use mems

* clean docs

* fix tf typo

* make xlnet tf for generation work

* fix tf test

* refactor use cache

* add use cache for missing models

* correct use_cache in generate

* correct use cache in tf generate

* fix tf

* correct getattr typo

* make sylvain happy

* change in docs as well

* do not apply to cookie cutter statements

* fix tf test

* make pytorch model fully backward compatible
parent 369f1d77
......@@ -69,6 +69,8 @@ class T5Config(PretrainedConfig):
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"]
......@@ -88,6 +90,7 @@ class T5Config(PretrainedConfig):
initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
**kwargs
......@@ -112,6 +115,7 @@ class T5Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
@property
def hidden_size(self):
......
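As a quick illustration of the new T5 flag above, here is a minimal sketch (assuming a transformers version that contains this change) of toggling ``use_cache`` on the config:

    from transformers import T5Config

    # `use_cache` now lives on the config; disabling it means the decoder
    # will not return past key/value states at inference time.
    config = T5Config(use_cache=False)
    print(config.use_cache)          # False

    # the default stays True, matching the docstring added in this diff
    assert T5Config().use_cache is True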
......@@ -884,7 +884,7 @@ T5_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
details.
To know more on how to prepare :obj:`inputs` for pre-training take a look at `T5 Training
To know more on how to prepare :obj:`inputs` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for sequence to sequence training. T5 uses the :obj:`pad_token_id` as the starting token for
......
......@@ -15,6 +15,8 @@
# limitations under the License.
""" XLNet configuration """
import warnings
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -106,12 +108,18 @@ class XLNetConfig(PretrainedConfig):
Used in the SQuAD evaluation script.
end_n_top (:obj:`int`, `optional`, defaults to 5):
Used in the SQuAD evaluation script.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last pre-computed hidden states.
use_mems_eval (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should make use of the recurrent memory mechanism in evaluation mode.
use_mems_train (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should make use of the recurrent memory mechanism in train mode.
.. note::
This flag behaves differently from other models: it only controls the inference behavior; during
training the model always uses ``use_cache=True``.
For pretraining, it is recommended to set ``use_mems_train`` to :obj:`True`. For fine-tuning, it is
recommended to set ``use_mems_train`` to :obj:`False` as discussed `here
<https://github.com/zihangdai/xlnet/issues/41#issuecomment-505102587>`__. If ``use_mems_train`` is set
to :obj:`True`, one has to make sure that the train batches are correctly pre-processed, `e.g.`
:obj:`batch_1 = [[This line is], [This is the]]` and :obj:`batch_2 = [[ the first line], [ second
line]]` and that all batches are of equal size.
Examples::
......@@ -145,6 +153,8 @@ class XLNetConfig(PretrainedConfig):
dropout=0.1,
mem_len=512,
reuse_len=None,
use_mems_eval=True,
use_mems_train=False,
bi_data=False,
clamp_len=-1,
same_length=False,
......@@ -197,6 +207,16 @@ class XLNetConfig(PretrainedConfig):
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
if "use_cache" in kwargs:
warnings.warn(
"The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems_eval` instead.",
FutureWarning,
)
use_mems_eval = kwargs["use_cache"]
self.use_mems_eval = use_mems_eval
self.use_mems_train = use_mems_train
@property
def max_position_embeddings(self):
return -1
......
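A small sketch of the new configuration flags and the deprecation shim shown above (again assuming the transformers release containing this commit):

    import warnings
    from transformers import XLNetConfig

    # new flags: memories on in eval mode, off in train mode (the new defaults)
    config = XLNetConfig(use_mems_eval=True, use_mems_train=False)

    # the old `use_cache` keyword is still accepted, but now raises a
    # FutureWarning and is mapped onto `use_mems_eval`
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy = XLNetConfig(use_cache=False)

    assert legacy.use_mems_eval is False
    assert any(issubclass(w.category, FutureWarning) for w in caught)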
......@@ -440,6 +440,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.layer = [TFXLNetLayer(config, name="layer_._{}".format(i)) for i in range(config.n_layer)]
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.use_mems_eval = config.use_mems_eval
self.use_mems_train = config.use_mems_train
def get_input_embeddings(self):
return self.word_embedding
......@@ -489,14 +492,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return ret
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
# cache hidden states into memory.
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[: self.reuse_len]
if self.mem_len is None or self.mem_len == 0:
# If :obj:`use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
# and returns all of the past and current hidden states.
cutoff = 0
else:
# If :obj:`use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
# states. This is the preferred setting for training and long-form generation.
cutoff = -self.mem_len
if prev_mem is None:
new_mem = curr_out[-self.mem_len :]
# if :obj:`use_mems` is active and `mem_len` is defined, the model
new_mem = curr_out[cutoff:]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len :]
new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff:]
return tf.stop_gradient(new_mem)
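To make the cutoff rule above concrete, here is a standalone toy version of the same logic (a sketch with made-up shapes, not the library method itself):

    import tensorflow as tf

    def cache_mem_sketch(curr_out, prev_mem, mem_len=None):
        # mem_len unset/0 -> keep the whole history (GPT-2-like behaviour);
        # otherwise keep only the last `mem_len` hidden states.
        cutoff = 0 if not mem_len else -mem_len
        if prev_mem is None:
            new_mem = curr_out[cutoff:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff:]
        return tf.stop_gradient(new_mem)

    curr = tf.random.normal((8, 2, 16))   # [seq_len, batch, hidden]
    prev = tf.random.normal((4, 2, 16))
    print(cache_mem_sketch(curr, prev, mem_len=6).shape)     # (6, 2, 16)
    print(cache_mem_sketch(curr, prev, mem_len=None).shape)  # (12, 2, 16)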
......@@ -569,7 +581,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -587,7 +599,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -602,6 +614,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if training:
use_mems = use_mems if use_mems is not None else self.use_mems_train
else:
use_mems = use_mems if use_mems is not None else self.use_mems_eval
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
......@@ -737,7 +754,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
# cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache:
if use_mems:
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
......@@ -768,7 +785,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2))
if not (self.mem_len is not None and self.mem_len > 0 and use_cache):
if not use_mems:
new_mems = None
if output_hidden_states:
if output_g is not None:
......@@ -1066,7 +1083,7 @@ XLNET_INPUTS_DOCSTRING = r"""
decoding. The token ids which have their past given to this model should not be passed as :obj:`input_ids`
as they have already been computed.
:obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
:obj:`use_mems` has to be set to :obj:`True` to make use of :obj:`mems`.
perm_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`):
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
......@@ -1147,7 +1164,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1165,7 +1182,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1182,7 +1199,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
......@@ -1207,7 +1224,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
def get_output_embeddings(self):
return self.lm_loss.input_embeddings
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
# At every pass, the attention values for the new token and the two last generated tokens
......@@ -1238,7 +1255,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
"input_ids": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
"use_mems": kwargs.get("use_mems"),
}
# if past is defined in model kwargs then use it for faster decoding
......@@ -1260,7 +1277,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1309,7 +1326,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1328,7 +1345,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
......@@ -1395,7 +1412,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1420,7 +1437,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1439,7 +1456,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
......@@ -1512,7 +1529,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
target_mapping=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1526,6 +1543,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
inputs = input_processing(
func=self.call,
input_ids=input_ids,
......@@ -1537,7 +1555,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1579,7 +1597,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
flat_input_mask,
inputs["head_mask"],
flat_inputs_embeds,
inputs["use_cache"],
inputs["use_mems"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
......@@ -1639,7 +1657,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1663,7 +1681,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1682,7 +1700,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
......@@ -1739,7 +1757,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1769,7 +1787,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1789,7 +1807,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
......
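On the TF side the public argument is now ``use_mems`` instead of ``use_cache``; a minimal usage sketch (assuming the standard xlnet-base-cased checkpoint is available):

    from transformers import TFXLNetModel, XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    model = TFXLNetModel.from_pretrained("xlnet-base-cased")

    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    # memories are returned when `use_mems` is on (outside of training this
    # is already the default via config.use_mems_eval)
    outputs = model(inputs["input_ids"], use_mems=True, return_dict=True)
    print(outputs.mems is None)  # False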
......@@ -16,6 +16,7 @@
"""
PyTorch XLNet model.
"""
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
......@@ -876,7 +877,7 @@ XLNET_INPUTS_DOCSTRING = r"""
decoding. The token ids which have their past given to this model should not be passed as :obj:`input_ids`
as they have already been computed.
:obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
:obj:`use_mems` has to be set to :obj:`True` to make use of :obj:`mems`.
perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`):
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
......@@ -997,15 +998,15 @@ class XLNetModel(XLNetPreTrainedModel):
curr_out = curr_out[: self.reuse_len]
if self.mem_len is None or self.mem_len == 0:
# If :obj:`use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
# If :obj:`use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
# and returns all of the past and current hidden states.
cutoff = 0
else:
# If :obj:`use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
# If :obj:`use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
# states. This is the preferred setting for training and long-form generation.
cutoff = -self.mem_len
if prev_mem is None:
# if :obj:`use_cache` is active and `mem_len` is defined, the model
# if :obj:`use_mems` is active and `mem_len` is defined, the model
new_mem = curr_out[cutoff:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
......@@ -1080,10 +1081,11 @@ class XLNetModel(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete after deprecation warning is removed
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......@@ -1091,7 +1093,18 @@ class XLNetModel(XLNetPreTrainedModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
if "use_cache" in kwargs:
warnings.warn(
"The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems` instead.",
FutureWarning,
)
use_mems = kwargs["use_cache"]
if self.training:
use_mems = use_mems if use_mems is not None else self.config.use_mems_train
else:
use_mems = use_mems if use_mems is not None else self.config.use_mems_eval
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
......@@ -1222,7 +1235,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
if use_cache:
if use_mems:
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if output_hidden_states:
......@@ -1253,7 +1266,7 @@ class XLNetModel(XLNetPreTrainedModel):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = output.permute(1, 0, 2).contiguous()
if not use_cache:
if not use_mems:
new_mems = None
if output_hidden_states:
......@@ -1299,7 +1312,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = input_ids.shape[0]
......@@ -1332,7 +1345,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": use_cache,
"use_mems": use_mems,
}
# if past is defined in model kwargs then use it for faster decoding
......@@ -1355,10 +1368,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete when `use_cache` is removed in XLNetModel
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`):
......@@ -1407,7 +1421,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
>>> next_token_logits = outputs.logits # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
transformer_outputs = self.transformer(
input_ids,
......@@ -1419,10 +1432,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
logits = self.lm_loss(transformer_outputs[0])
......@@ -1483,10 +1497,11 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete when `use_cache` is removed in XLNetModel
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1495,7 +1510,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
transformer_outputs = self.transformer(
input_ids,
......@@ -1507,10 +1521,11 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
output = transformer_outputs[0]
......@@ -1576,10 +1591,11 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete when `use_cache` is removed in XLNetModel
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1588,7 +1604,6 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
outputs = self.transformer(
input_ids,
......@@ -1600,7 +1615,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1673,10 +1688,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete when `use_cache` is removed in XLNetModel
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1685,7 +1701,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
:obj:`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
......@@ -1708,10 +1724,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
target_mapping=target_mapping,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
output = transformer_outputs[0]
......@@ -1775,10 +1792,11 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
inputs_embeds=None,
start_positions=None,
end_positions=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete when `use_cache` is removed in XLNetModel
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1791,7 +1809,6 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
outputs = self.transformer(
input_ids,
......@@ -1803,10 +1820,11 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
sequence_output = outputs[0]
......@@ -1885,10 +1903,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
is_impossible=None,
cls_index=None,
p_mask=None,
use_cache=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs, # delete when `use_cache` is removed in XLNetModel
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1926,7 +1945,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
>>> loss = outputs.loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
transformer_outputs = self.transformer(
input_ids,
......@@ -1938,10 +1956,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......
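The PyTorch forward pass stays backward compatible: the old keyword is caught via ``**kwargs``, warned about, and forwarded as ``use_mems``. A sketch (same checkpoint assumption as above):

    import torch
    from transformers import XLNetModel, XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    model = XLNetModel.from_pretrained("xlnet-base-cased").eval()

    input_ids = tokenizer("Hello world", return_tensors="pt")["input_ids"]

    with torch.no_grad():
        new_style = model(input_ids, use_mems=True, return_dict=True)
        # old keyword still accepted, but emits a FutureWarning
        old_style = model(input_ids, use_cache=True, return_dict=True)

    print(new_style.mems is not None, old_style.mems is not None)  # True True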
......@@ -153,7 +153,7 @@ class TFXLNetModelTester:
inputs = [input_ids_1, input_mask]
result = model(inputs)
config.mem_len = 0
config.use_mems_eval = False
model = TFXLNetModel(config)
no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1)
......
......@@ -206,7 +206,36 @@ class XLNetModelTester:
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_model_use_cache(
def create_and_check_use_mems_train(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetForSequenceClassification(config)
model.to(torch_device)
model.train()
train_size = input_ids_1.shape[0]
batch_size = 4
for i in range(train_size // batch_size + 1):
input_ids = input_ids_1[i * batch_size : (i + 1) * batch_size]
labels = sequence_labels[i * batch_size : (i + 1) * batch_size]
outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
self.parent.assertIsNone(outputs.mems)
self.parent.assertIsNotNone(outputs.loss)
def create_and_check_xlnet_model_use_mems(
self,
config,
input_ids_1,
......@@ -234,8 +263,8 @@ class XLNetModelTester:
device=torch_device,
)
causal_mask = torch.triu(causal_mask, diagonal=0)
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
outputs_cache = model(input_ids_1, use_mems=True, perm_mask=causal_mask)
outputs_no_cache = model(input_ids_1, use_mems=False, perm_mask=causal_mask)
outputs_conf = model(input_ids_1)
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
......@@ -525,11 +554,15 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_base_model_use_cache(self):
# checking that in auto-regressive mode, :obj:`use_cache` gives the same results
def test_xlnet_base_model_use_mems(self):
# checking that in auto-regressive mode, :obj:`use_mems` gives the same results
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
self.model_tester.create_and_check_xlnet_model_use_mems(*config_and_inputs)
def test_seq_classification_use_mems_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_use_mems_train(*config_and_inputs)
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()
......
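Finally, since this commit also fixes mems handling during generation, here is a short decoding sketch with the renamed argument (output text will vary; exact generate kwargs may differ slightly by library version):

    from transformers import XLNetLMHeadModel, XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")

    prompt = tokenizer("Today is a beautiful day and", return_tensors="pt")

    # `use_mems` is threaded through prepare_inputs_for_generation; leaving it
    # unset falls back to config.use_mems_eval (True) outside of training
    generated = model.generate(prompt["input_ids"], max_length=20, do_sample=False)
    print(tokenizer.decode(generated[0]))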