Commit 0f9fc4fb authored by thomwolf
Browse files

adding option to deactivate past/memory outputs

parent 2a4fef83
...@@ -53,7 +53,8 @@ class PretrainedConfig(object): ...@@ -53,7 +53,8 @@ class PretrainedConfig(object):
self.num_labels = kwargs.pop('num_labels', 2) self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False) self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False) self.output_past = kwargs.pop('output_past', True) # Not used by all models
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {}) self.pruned_heads = kwargs.pop('pruned_heads', {})
......
...@@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super(CTRLModel, self).__init__(config) super(CTRLModel, self).__init__(config)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd self.d_model_size = config.n_embd
self.num_layers = config.n_layer self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float) self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
self.output_attentions = config.output_attentions
self.w = nn.Embedding(config.vocab_size, config.n_embd) self.w = nn.Embedding(config.vocab_size, config.n_embd)
self.dropout = nn.Dropout(config.embd_pdrop) self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([EncoderLayer(config.n_embd, self.h = nn.ModuleList([EncoderLayer(config.n_embd,
config.n_head, config.n_head,
...@@ -378,6 +378,7 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -378,6 +378,7 @@ class CTRLModel(CTRLPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask[i]) head_mask=head_mask[i])
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if self.output_past:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if self.output_attentions:
...@@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents) outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if self.output_attentions:
......
...@@ -347,6 +347,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -347,6 +347,7 @@ class GPT2Model(GPT2PreTrainedModel):
super(GPT2Model, self).__init__(config) super(GPT2Model, self).__init__(config)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd)
...@@ -440,6 +441,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -440,6 +441,7 @@ class GPT2Model(GPT2PreTrainedModel):
head_mask=head_mask[i]) head_mask=head_mask[i])
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if self.output_past:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if self.output_attentions:
...@@ -452,7 +454,9 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -452,7 +454,9 @@ class GPT2Model(GPT2PreTrainedModel):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents) outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if self.output_attentions:
...@@ -460,7 +464,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -460,7 +464,7 @@ class GPT2Model(GPT2PreTrainedModel):
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last hidden state, presents, (all hidden_states), (attentions) return outputs # last hidden state, (presents), (all hidden_states), (attentions)
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top @add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
......
...@@ -168,12 +168,14 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -168,12 +168,14 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(TFCTRLMainLayer, self).__init__(**kwargs) super(TFCTRLMainLayer, self).__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd self.d_model_size = config.n_embd
self.num_layers = config.n_layer self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size) self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
self.output_attentions = config.output_attentions
self.w = TFSharedEmbeddings(config.vocab_size, self.w = TFSharedEmbeddings(config.vocab_size,
config.n_embd, config.n_embd,
...@@ -290,6 +292,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -290,6 +292,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training) outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if self.output_past:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if self.output_attentions:
...@@ -300,7 +304,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -300,7 +304,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents) outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if self.output_attentions:
......
...@@ -354,6 +354,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -354,6 +354,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
super(TFXLNetMainLayer, self).__init__(**kwargs) super(TFXLNetMainLayer, self).__init__(**kwargs)
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len self.mem_len = config.mem_len
self.reuse_len = config.reuse_len self.reuse_len = config.reuse_len
...@@ -413,9 +414,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -413,9 +414,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def cache_mem(self, curr_out, prev_mem): def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory.""" """cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0: if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len] curr_out = curr_out[:self.reuse_len]
...@@ -538,7 +536,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -538,7 +536,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError('Unsupported attention type: {}'.format(self.attn_type)) raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask # data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one." "or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
if input_mask is None and attention_mask is not None: if input_mask is None and attention_mask is not None:
input_mask = 1.0 - attention_mask input_mask = 1.0 - attention_mask
...@@ -624,6 +622,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -624,6 +622,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states = [] hidden_states = []
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states: if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
...@@ -642,7 +641,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -642,7 +641,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output = self.dropout(output_g if output_g is not None else output_h, training=training) output = self.dropout(output_g if output_g is not None else output_h, training=training)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (tf.transpose(output, perm=(1, 0, 2)), new_mems) outputs = (tf.transpose(output, perm=(1, 0, 2)),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
outputs = outputs + (new_mems,)
if self.output_hidden_states: if self.output_hidden_states:
if output_g is not None: if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
...@@ -653,7 +656,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -653,7 +656,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, new_mems, (hidden_states), (attentions) return outputs # outputs, (new_mems), (hidden_states), (attentions)
class TFXLNetPreTrainedModel(TFPreTrainedModel): class TFXLNetPreTrainedModel(TFPreTrainedModel):
...@@ -768,7 +771,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -768,7 +771,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)`` **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model. Sequence of hidden-states at the last layer of the model.
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer): list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -810,7 +813,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): ...@@ -810,7 +813,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer): list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -854,7 +857,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): ...@@ -854,7 +857,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, mems, (hidden states), (attentions) return outputs # return logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
...@@ -865,7 +868,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel): ...@@ -865,7 +868,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)`` **logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer): list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -909,7 +912,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel): ...@@ -909,7 +912,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, mems, (hidden states), (attentions) return outputs # return logits, (mems), (hidden states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
...@@ -923,6 +926,11 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel): ...@@ -923,6 +926,11 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
**end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)`` **end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0`` and ``config.output_past=True``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model.
Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
...@@ -962,7 +970,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel): ...@@ -962,7 +970,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # start_logits, end_logits, (hidden_states), (attentions) return outputs # start_logits, end_logits, (mems), (hidden_states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """, # the hidden-states output to compute `span start logits` and `span end logits`). """,
......
...@@ -555,7 +555,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -555,7 +555,7 @@ class XLNetModel(XLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model. Sequence of hidden-states at the last layer of the model.
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -581,6 +581,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -581,6 +581,7 @@ class XLNetModel(XLNetPreTrainedModel):
super(XLNetModel, self).__init__(config) super(XLNetModel, self).__init__(config)
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len self.mem_len = config.mem_len
self.reuse_len = config.reuse_len self.reuse_len = config.reuse_len
...@@ -637,9 +638,6 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -637,9 +638,6 @@ class XLNetModel(XLNetPreTrainedModel):
def cache_mem(self, curr_out, prev_mem): def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory.""" """cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0: if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len] curr_out = curr_out[:self.reuse_len]
...@@ -817,6 +815,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -817,6 +815,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = [] attentions = []
hidden_states = [] hidden_states = []
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
# cache new mems # cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states: if self.output_hidden_states:
...@@ -836,7 +835,11 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -836,7 +835,11 @@ class XLNetModel(XLNetPreTrainedModel):
output = self.dropout(output_g if output_g is not None else output_h) output = self.dropout(output_g if output_g is not None else output_h)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (output.permute(1, 0, 2).contiguous(), new_mems) outputs = (output.permute(1, 0, 2).contiguous(),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
outputs = outputs + (new_mems,)
if self.output_hidden_states: if self.output_hidden_states:
if output_g is not None: if output_g is not None:
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs) hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
...@@ -847,7 +850,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -847,7 +850,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions) attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, new_mems, (hidden_states), (attentions) return outputs # outputs, (new_mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a language modeling head on top @add_start_docstrings("""XLNet Model with a language modeling head on top
...@@ -867,7 +870,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -867,7 +870,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Language modeling loss. Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -932,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -932,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
labels.view(-1)) labels.view(-1))
outputs = (loss,) + outputs outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
...@@ -951,7 +954,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -951,7 +954,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Classification (or regression if config.num_labels==1) loss. Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -1011,7 +1014,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1011,7 +1014,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of @add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """, the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
...@@ -1046,6 +1049,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): ...@@ -1046,6 +1049,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above). of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax). Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0`` and ``config.output_past=True``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model.
Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
...@@ -1102,7 +1110,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): ...@@ -1102,7 +1110,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
loss = loss_fct(reshaped_logits, labels.view(-1)) loss = loss_fct(reshaped_logits, labels.view(-1))
outputs = (loss,) + outputs outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
...@@ -1126,7 +1134,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1126,7 +1134,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -1197,7 +1205,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1197,7 +1205,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
...@@ -1239,7 +1247,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): ...@@ -1239,7 +1247,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.FloatTensor`` of shape ``(batch_size,)`` ``torch.FloatTensor`` of shape ``(batch_size,)``
Log probabilities for the ``is_impossible`` label of the answers. Log probabilities for the ``is_impossible`` label of the answers.
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
......
...@@ -161,6 +161,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -161,6 +161,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"outputs": outputs.numpy(), "outputs": outputs.numpy(),
} }
model.config.mem_len = 0
no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["outputs"].shape), list(result["outputs"].shape),
[self.batch_size, self.seq_length, self.hidden_size]) [self.batch_size, self.seq_length, self.hidden_size])
......
...@@ -150,6 +150,10 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -150,6 +150,10 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"outputs": outputs, "outputs": outputs,
} }
model.config.mem_len = 0
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["outputs"].size()), list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size]) [self.batch_size, self.seq_length, self.hidden_size])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment