Unverified commit 227e0a40 authored by Teven, committed by GitHub

Fixed use of memories in XLNet (caching for language generation + warning when loading improper memoryless model) (#5632)

* Pytorch gpu => cpu proper device

* Memoryless XLNet warning + fixed memories during generation

* Revert "Pytorch gpu => cpu proper device"

This reverts commit 93489b36

* made black happy

* TF generation with memories

* dim => axis

* added padding_text to TF XL models

* Added comment, added TF
parent 3b7b6465
@@ -15,8 +15,8 @@
 # limitations under the License.
 """ XLNet configuration """
 import logging
+import warnings

 from .configuration_utils import PretrainedConfig
@@ -195,6 +195,17 @@ class XLNetConfig(PretrainedConfig):
         self.pad_token_id = pad_token_id
         self.eos_token_id = eos_token_id
+        if mem_len is None or mem_len == 0:
+            warnings.warn(
+                "This config doesn't use attention memories, a core feature of XLNet."
+                " Consider setting `mem_len` to a non-zero value, for example"
+                " `xlnet = XLNetLMHeadModel.from_pretrained('xlnet-base-cased', mem_len=1024)`,"
+                " for accurate training performance as well as an order of magnitude faster inference."
+                " Starting from version 3.5.0, the default parameter will be 1024, following"
+                " the implementation in https://arxiv.org/abs/1906.08237",
+                FutureWarning,
+            )

     @property
     def max_position_embeddings(self):
         return -1
...
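The warning message itself shows the fix. A minimal loading sketch, using the checkpoint name taken from the warning text:

from transformers import XLNetLMHeadModel

# Enable attention memories at load time. mem_len=1024 matches the value the
# warning says will become the default in version 3.5.0.
xlnet = XLNetLMHeadModel.from_pretrained("xlnet-base-cased", mem_len=1024)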
@@ -884,8 +884,17 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_inputs_for_generation(self, inputs, past, **kwargs):
         # Add dummy token at the end (no attention on this one)
+        # At every pass, the attention values for the new token and the two last generated tokens
+        # are computed; the rest is reloaded from the `past` cache. A purely auto-regressive model
+        # would use offset = 1; offset = 2 seems to perform slightly better.
+        offset = 2
         effective_batch_size = inputs.shape[0]
         dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
-        inputs = tf.concat([inputs, dummy_token], axis=1)
+        if past:
+            inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
+        else:
+            inputs = tf.concat([inputs, dummy_token], axis=1)
         # Build permutation mask so that previous tokens don't see last token
@@ -908,7 +917,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         # if past is defined in model kwargs then use it for faster decoding
         if past:
-            inputs["mems"] = past
+            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
         return inputs
...
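To make the offset bookkeeping above concrete, here is a standalone sketch of what the TF path now does when `past` is set (toy tensors and sizes, not the library code): feed only the last `offset` tokens plus the dummy token, and trim the last `offset` positions from each cached memory layer so they get recomputed.

import tensorflow as tf

offset = 2
generated = tf.constant([[17, 42, 99, 7]], dtype=tf.int32)  # toy token ids, batch size 1
past = [tf.random.normal((10, 1, 8)) for _ in range(2)]     # toy mems: (mem_len, batch, hidden)

dummy_token = tf.zeros((generated.shape[0], 1), dtype=tf.int32)
# Keep only the last `offset` generated tokens and append the dummy token.
step_inputs = tf.concat([generated[:, -offset:], dummy_token], axis=1)
# Drop the last `offset` positions of each memory layer; they will be recomputed.
mems = tuple(layer_past[:-offset, :, :] for layer_past in past)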
@@ -1261,6 +1261,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         effective_batch_size = input_ids.shape[0]
         dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device)
+        # At every pass, the attention values for the new token and the two last generated tokens
+        # are computed; the rest is reloaded from the `past` cache. A purely auto-regressive model
+        # would use offset = 1; offset = 2 seems to perform slightly better.
+        offset = 2
-        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+        if past:
+            input_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
+        else:
+            input_ids = torch.cat([input_ids, dummy_token], dim=1)
         # Build permutation mask so that previous tokens don't see last token
@@ -1285,7 +1294,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         # if past is defined in model kwargs then use it for faster decoding
         if past:
-            inputs["mems"] = past
+            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
         return inputs
...
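With the same fix on the PyTorch side, cached generation can be exercised end to end. A hedged usage sketch: the checkpoint name and `mem_len=1024` come from the config warning, while the prompt and generation arguments are illustrative assumptions.

import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
# mem_len=1024 enables the memories that generate() can now reuse across steps.
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased", mem_len=1024)

input_ids = tokenizer.encode("XLNet memories make generation", return_tensors="pt")
with torch.no_grad():
    output = model.generate(input_ids, max_length=40, do_sample=True)
print(tokenizer.decode(output[0]))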
@@ -653,7 +653,12 @@ class TextGenerationPipeline(Pipeline):
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
-                if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
+                if self.model.__class__.__name__ in [
+                    "XLNetLMHeadModel",
+                    "TransfoXLLMHeadModel",
+                    "TFXLNetLMHeadModel",
+                    "TFTransfoXLLMHeadModel",
+                ]:
                     # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                     padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
                     padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
...
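Since the class-name check now also matches the TF heads, the padding text is prepended for TensorFlow XLNet/Transfo-XL models too. A minimal pipeline sketch; the model choice and generation arguments are assumptions, not part of this diff:

from transformers import pipeline

# The pipeline prepends PADDING_TEXT for XLNet/Transfo-XL heads, TF or PyTorch,
# so short prompts still give the model enough state to generate from.
generator = pipeline("text-generation", model="xlnet-base-cased", framework="tf")
print(generator("Once upon a time,", max_length=60))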