Unverified commit 88ef8893, authored by Suraj Patil, committed by GitHub

Add caching mechanism to BERT, RoBERTa (#9183)

* add past_key_values

* add use_cache option

* make mask before cutting ids

* adjust position_ids according to past_key_values

* flatten past_key_values

* fix positional embeds

* fix _reorder_cache

* set use_cache to false when not decoder, fix attention mask init

* add test for caching

* add past_key_values for Roberta

* fix position embeds

* add caching test for roberta

* add doc

* make style

* doc, fix attention mask, test

* small fixes

* address Patrick's comments

* input_ids shouldn't start with pad token

* use_cache only when decoder

* make consistent with bert

* make copies consistent

* add use_cache to encoder

* add past_key_values to tapas attention

* apply suggestions from code review

* make copies consistent

* add attn mask in tests

* remove copied from longformer

* apply suggestions from code review

* fix bart test

* nit

* simplify model outputs

* fix doc

* fix output ordering
parent a1cb6e98
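
For orientation, a minimal usage sketch of the caching path this commit adds (illustrative only, not part of the diff; it assumes a transformers version that already contains this change and the public bert-base-uncased checkpoint):

# Incremental decoding with the new BERT cache (illustrative sketch).
import torch
from transformers import BertConfig, BertLMHeadModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True)  # caching requires a decoder config
model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=config)
model.eval()

input_ids = tokenizer("Paris is the capital of", return_tensors="pt").input_ids

with torch.no_grad():
    # First pass over the full prefix returns the per-layer key/value cache.
    outputs = model(input_ids, use_cache=True, return_dict=True)
    past_key_values = outputs.past_key_values

    # Later passes only feed the newest token; the cache covers the earlier positions.
    next_token = outputs.logits[:, -1:].argmax(-1)
    outputs = model(next_token, past_key_values=past_key_values, use_cache=True, return_dict=True)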
@@ -126,13 +126,6 @@ CausalLMOutputWithCrossAttentions
     :members:
-CausalLMOutputWithPastAndCrossAttentions
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
-    :members:
 CausalLMOutputWithPast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -175,11 +175,19 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
             Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
             weighted average in the cross-attention heads.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the
+            cached key, value states of the self-attention and the cross-attention layers if model is used in
+            encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            :obj:`past_key_values` input) to speed up sequential decoding.
     """
     last_hidden_state: torch.FloatTensor = None
     pooler_output: torch.FloatTensor = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
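
The new past_key_values field added above is a tuple with one entry per layer; for a plain decoder each entry holds that layer's cached self-attention key and value states. A quick way to inspect the structure (illustrative sketch, not part of the diff, assuming the bert-base-uncased checkpoint):

# Inspect the cache structure returned by BertModel when it runs as a decoder.
import torch
from transformers import BertConfig, BertModel

config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True)
model = BertModel.from_pretrained("bert-base-uncased", config=config)

input_ids = torch.tensor([[101, 7592, 2088, 102]])
outputs = model(input_ids, use_cache=True, return_dict=True)

print(len(outputs.past_key_values))       # one entry per hidden layer
key, value = outputs.past_key_values[0]   # self-attention cache of the first layer
print(key.shape)                          # (batch_size, num_heads, seq_len, head_dim)
print(value.shape)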
@@ -379,53 +387,18 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
             Cross attentions weights after the attention softmax, used to compute the weighted average in the
             cross-attention heads.
-    """
-    loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
-@dataclass
-class CausalLMOutputWithPastAndCrossAttentions(ModelOutput):
-    """
-    Base class for causal language model (or autoregressive) outputs.
-    Args:
-        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
-            Language modeling loss (for next-token prediction).
-        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
-            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
-            batch_size, num_heads, sequence_length, embed_size_per_head)`).
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the
+            cached key, value states of the self-attention and the cross-attention layers if model is used in
+            encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
             Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
             :obj:`past_key_values` input) to speed up sequential decoding.
-        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
-            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
-            sequence_length, sequence_length)`.
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
-            sequence_length, sequence_length)`.
-            Cross attentions weights after the attention softmax, used to compute the weighted average in the
-            cross-attention heads.
     """
     loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
-    past_key_values: Optional[List[torch.FloatTensor]] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
...
@@ -217,7 +217,9 @@ class AlbertEmbeddings(nn.Module):
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
-    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -226,7 +228,7 @@ class AlbertEmbeddings(nn.Module):
         seq_length = input_shape[1]
         if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
...
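
The embedding change above (mirrored in ALBERT via the copied-from forward) is what keeps absolute position embeddings correct when only the newest token is passed in: position_ids are sliced starting at past_key_values_length. A toy illustration of the slicing (not part of the diff):

# Why position_ids must be offset by the number of cached positions.
import torch

position_ids = torch.arange(512).expand((1, -1))  # same buffer the embeddings register

past_key_values_length = 5   # five tokens already cached
seq_length = 1               # only the newest token is fed to the model

# Without the offset the new token would get position 0; with it, it gets position 5.
print(position_ids[:, past_key_values_length : seq_length + past_key_values_length])  # tensor([[5]])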
@@ -98,6 +98,9 @@ class BertConfig(PretrainedConfig):
             <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
             `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
             <https://arxiv.org/abs/2009.13658>`__.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if ``config.is_decoder=True``.
     Examples::
@@ -131,6 +134,7 @@ class BertConfig(PretrainedConfig):
         pad_token_id=0,
         gradient_checkpointing=False,
         position_embedding_type="absolute",
+        use_cache=True,
         **kwargs
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -149,3 +153,4 @@ class BertConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.gradient_checkpointing = gradient_checkpointing
         self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
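
The new use_cache flag is only a default; it can still be overridden per forward call, and it only has an effect when the config also marks the model as a decoder. A small configuration sketch (illustrative, not part of the diff):

# use_cache as a config-level default for a decoder-style BERT.
from transformers import BertConfig

config = BertConfig(
    is_decoder=True,            # the cache is only returned when the model acts as a decoder
    add_cross_attention=False,  # set True when the decoder is paired with an encoder
    use_cache=True,             # default used when use_cache is not passed to forward()
)
print(config.use_cache)  # True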
@@ -36,7 +36,7 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import (
-    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
@@ -180,7 +180,9 @@ class BertEmbeddings(nn.Module):
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -189,7 +191,7 @@ class BertEmbeddings(nn.Module):
         seq_length = input_shape[1]
         if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@@ -230,6 +232,8 @@ class BertSelfAttention(nn.Module):
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        self.is_decoder = config.is_decoder
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -242,6 +246,7 @@ class BertSelfAttention(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_value=None,
         output_attentions=False,
     ):
         mixed_query_layer = self.query(hidden_states)
@@ -249,17 +254,37 @@ class BertSelfAttention(nn.Module):
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
-        if encoder_hidden_states is not None:
-            mixed_key_layer = self.key(encoder_hidden_states)
-            mixed_value_layer = self.value(encoder_hidden_states)
+        is_cross_attention = encoder_hidden_states is not None
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            mixed_key_layer = self.key(hidden_states)
-            mixed_value_layer = self.value(hidden_states)
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)
-        key_layer = self.transpose_for_scores(mixed_key_layer)
-        value_layer = self.transpose_for_scores(mixed_value_layer)
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@@ -303,6 +328,9 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.view(*new_context_layer_shape)
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
         return outputs
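
The heart of the mechanism is the elif past_key_value is not None branch above: previously projected key/value states are concatenated with the projections of the new token along the sequence axis (dim=2). A standalone tensor sketch of that step (illustrative, not part of the diff):

# Cached key states grow along the sequence dimension (dim=2).
import torch

batch, heads, head_dim = 1, 12, 64
past_key = torch.randn(batch, heads, 5, head_dim)  # 5 positions already cached
new_key = torch.randn(batch, heads, 1, head_dim)   # projection of the newest token only

key_layer = torch.cat([past_key, new_key], dim=2)
print(key_layer.shape)  # torch.Size([1, 12, 6, 64])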
@@ -352,6 +380,7 @@ class BertAttention(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_value=None,
         output_attentions=False,
     ):
         self_outputs = self.self(
@@ -360,6 +389,7 @@ class BertAttention(nn.Module):
             head_mask,
             encoder_hidden_states,
             encoder_attention_mask,
+            past_key_value,
             output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
@@ -417,36 +447,60 @@ class BertLayer(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_value=None,
         output_attentions=False,
     ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
             head_mask,
             output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
         )
         attention_output = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+        # if decoder, the last output is tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+        cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
             assert hasattr(
                 self, "crossattention"
             ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
             cross_attention_outputs = self.crossattention(
                 attention_output,
                 attention_mask,
                 head_mask,
                 encoder_hidden_states,
                 encoder_attention_mask,
+                cross_attn_past_key_value,
                 output_attentions,
             )
             attention_output = cross_attention_outputs[0]
-            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
         layer_output = apply_chunking_to_forward(
             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
         outputs = (layer_output,) + outputs
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
         return outputs

     def feed_forward_chunk(self, attention_output):
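
Per layer, the cache is therefore a flat tuple: the first two tensors are the self-attention key/value (which grow step by step), and, when cross-attention is present, the last two are the cross-attention key/value (projected once from the encoder states and then reused). A small sketch of that layout (illustrative, not part of the diff):

# Per-layer cache layout for a decoder layer that also has cross-attention.
import torch

self_k = torch.randn(1, 12, 6, 64)     # grows with each generated token
self_v = torch.randn(1, 12, 6, 64)
cross_k = torch.randn(1, 12, 20, 64)   # fixed: length of the encoder sequence
cross_v = torch.randn(1, 12, 20, 64)

past_key_value = (self_k, self_v, cross_k, cross_v)
self_attn_past_key_value = past_key_value[:2]    # consumed by self-attention
cross_attn_past_key_value = past_key_value[-2:]  # consumed by cross-attention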
@@ -468,6 +522,8 @@ class BertEncoder(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -475,17 +531,19 @@ class BertEncoder(nn.Module):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
             if getattr(self.config, "gradient_checkpointing", False):

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
+                        return module(*inputs, past_key_value, output_attentions)

                     return custom_forward
@@ -504,9 +562,13 @@ class BertEncoder(nn.Module):
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
                     output_attentions,
                 )
             hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
                 if self.config.add_cross_attention:
@@ -518,11 +580,18 @@ class BertEncoder(nn.Module):
         if not return_dict:
             return tuple(
                 v
-                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
                 if v is not None
             )
-        return BaseModelOutputWithCrossAttentions(
+        return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
             cross_attentions=all_cross_attentions,
@@ -799,6 +868,8 @@ class BertModel(BertPreTrainedModel):
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -813,6 +884,15 @@ class BertModel(BertPreTrainedModel):
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -820,19 +900,29 @@ class BertModel(BertPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
         device = input_ids.device if input_ids is not None else inputs_embeds.device
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape, device=device)
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
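
When only the newest token is fed in, the default attention mask built above still has to cover the cached positions, which is why its width is seq_length + past_key_values_length rather than the width of input_ids. In isolation (illustrative, not part of the diff):

# The default attention mask spans cached plus current positions.
import torch

batch_size, seq_length = 1, 1   # only the newest token is passed to the model
past_key_values_length = 5      # five positions already in the cache

attention_mask = torch.ones((batch_size, seq_length + past_key_values_length))
print(attention_mask.shape)  # torch.Size([1, 6]) -> the new token attends to all 6 positions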
@@ -859,7 +949,11 @@ class BertModel(BertPreTrainedModel):
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         embedding_output = self.embeddings(
-            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
         )
         encoder_outputs = self.encoder(
             embedding_output,
@@ -867,6 +961,8 @@ class BertModel(BertPreTrainedModel):
             head_mask=head_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -880,6 +976,7 @@ class BertModel(BertPreTrainedModel):
         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
             cross_attentions=encoder_outputs.cross_attentions,
@@ -1029,6 +1126,8 @@ class BertLMHeadModel(BertPreTrainedModel):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         labels=None,
+        past_key_values=None,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -1047,6 +1146,15 @@ class BertLMHeadModel(BertPreTrainedModel):
             Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
             ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
             ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
         Returns:
@@ -1066,6 +1174,8 @@ class BertLMHeadModel(BertPreTrainedModel):
         >>> prediction_logits = outputs.logits
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
         outputs = self.bert(
             input_ids,
@@ -1076,6 +1186,8 @@ class BertLMHeadModel(BertPreTrainedModel):
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1099,20 +1211,30 @@ class BertLMHeadModel(BertPreTrainedModel):
         return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
         )
-    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
         return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
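
During beam search the batch dimension of every cached tensor has to follow the beams that were kept at each step; that is all _reorder_cache does. The same logic outside the model (illustrative, not part of the diff):

# Reordering the cache so each surviving beam keeps the history it was expanded from.
import torch

num_beams, heads, seq, head_dim = 3, 12, 6, 64
layer_past = tuple(torch.randn(num_beams, heads, seq, head_dim) for _ in range(2))
past = (layer_past,)                # a single layer with its (key, value) pair

beam_idx = torch.tensor([2, 0, 0])  # beams selected at this decoding step

reordered_past = tuple(
    tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
    for layer_past in past
)
print(reordered_past[0][0].shape)  # torch.Size([3, 12, 6, 64])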
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
...
@@ -26,7 +26,7 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from ..bert.modeling_bert import BertEncoder
@@ -144,7 +144,7 @@ class BertGenerationEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
+    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -153,7 +153,7 @@ class BertGenerationEmbeddings(nn.Module):
         seq_length = input_shape[1]
         if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -297,7 +297,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
-        output_type=BaseModelOutputWithCrossAttentions,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -309,6 +309,8 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -321,6 +323,15 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for
             tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -332,19 +343,28 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
         device = input_ids.device if input_ids is not None else inputs_embeds.device
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape, device=device)
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask = None
+        if not use_cache:
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                attention_mask, input_shape, device
+            )
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -364,7 +384,12 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
         encoder_outputs = self.encoder(
             embedding_output,
@@ -372,6 +397,8 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
             head_mask=head_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -381,8 +408,9 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
         if not return_dict:
             return (sequence_output,) + encoder_outputs[1:]
-        return BaseModelOutputWithCrossAttentions(
+        return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=sequence_output,
+            past_key_values=encoder_outputs.past_key_values,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
             cross_attentions=encoder_outputs.cross_attentions,
@@ -437,6 +465,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         labels=None,
+        past_key_values=None,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -455,6 +485,15 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
             Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
             ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
             ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
         Returns:
@@ -474,6 +513,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
         >>> prediction_logits = outputs.logits
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
         outputs = self.bert(
             input_ids,
@@ -483,6 +524,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -506,16 +549,26 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
         return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
         )
-    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
         return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
@@ -33,6 +33,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPastAndCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
@@ -168,7 +169,9 @@ class ElectraEmbeddings(nn.Module):
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
-    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -177,7 +180,7 @@ class ElectraEmbeddings(nn.Module):
         seq_length = input_shape[1]
         if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@@ -219,6 +222,8 @@ class ElectraSelfAttention(nn.Module):
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        self.is_decoder = config.is_decoder
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -231,6 +236,7 @@ class ElectraSelfAttention(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        past_key_value=None,
         output_attentions=False,
     ):
         mixed_query_layer = self.query(hidden_states)
@@ -238,17 +244,37 @@ class ElectraSelfAttention(nn.Module):
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
-        if encoder_hidden_states is not None:
-            mixed_key_layer = self.key(encoder_hidden_states)
-            mixed_value_layer = self.value(encoder_hidden_states)
+        is_cross_attention = encoder_hidden_states is not None
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            mixed_key_layer = self.key(hidden_states)
-            mixed_value_layer = self.value(hidden_states)
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)
-        key_layer = self.transpose_for_scores(mixed_key_layer)
-        value_layer = self.transpose_for_scores(mixed_value_layer)
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@@ -292,6 +318,9 @@ class ElectraSelfAttention(nn.Module):
         context_layer = context_layer.view(*new_context_layer_shape)
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
         return outputs
...@@ -343,6 +372,7 @@ class ElectraAttention(nn.Module): ...@@ -343,6 +372,7 @@ class ElectraAttention(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
self_outputs = self.self( self_outputs = self.self(
...@@ -351,6 +381,7 @@ class ElectraAttention(nn.Module): ...@@ -351,6 +381,7 @@ class ElectraAttention(nn.Module):
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions, output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
...@@ -411,36 +442,60 @@ class ElectraLayer(nn.Module): ...@@ -411,36 +442,60 @@ class ElectraLayer(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
assert hasattr( assert hasattr(
self, "crossattention" self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention( cross_attention_outputs = self.crossattention(
attention_output, attention_output,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
cross_attn_past_key_value,
output_attentions, output_attentions,
) )
attention_output = cross_attention_outputs[0] attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
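As the slicing above implies, a decoder layer's cache entry is a 4-tuple: self-attention key/value in positions 0-1 (growing by one step per generated token) and cross-attention key/value in positions 2-3 (computed once from the encoder output). A minimal sketch with made-up shapes, not taken from the diff:

    import torch

    # Hypothetical cache entry for one decoder layer, matching the [:2] / [-2:] slicing above.
    self_k = torch.randn(2, 4, 10, 16)   # (batch, heads, decoded_len, head_dim)
    self_v = torch.randn(2, 4, 10, 16)
    cross_k = torch.randn(2, 4, 7, 16)   # (batch, heads, encoder_len, head_dim)
    cross_v = torch.randn(2, 4, 7, 16)
    past_key_value = (self_k, self_v, cross_k, cross_v)

    self_attn_past_key_value = past_key_value[:2]    # re-used and extended every decoding step
    cross_attn_past_key_value = past_key_value[-2:]  # re-used unchanged every decoding step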
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
...@@ -463,6 +518,8 @@ class ElectraEncoder(nn.Module): ...@@ -463,6 +518,8 @@ class ElectraEncoder(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
return_dict=True, return_dict=True,
...@@ -470,17 +527,19 @@ class ElectraEncoder(nn.Module): ...@@ -470,17 +527,19 @@ class ElectraEncoder(nn.Module):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
return module(*inputs, output_attentions) return module(*inputs, past_key_value, output_attentions)
return custom_forward return custom_forward
...@@ -499,9 +558,13 @@ class ElectraEncoder(nn.Module): ...@@ -499,9 +558,13 @@ class ElectraEncoder(nn.Module):
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions, output_attentions,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention: if self.config.add_cross_attention:
...@@ -513,11 +576,18 @@ class ElectraEncoder(nn.Module): ...@@ -513,11 +576,18 @@ class ElectraEncoder(nn.Module):
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None if v is not None
) )
return BaseModelOutputWithCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
......
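At the encoder level the cache is a tuple with one entry per layer; each forward pass indexes it with the layer counter and re-collects the fresh entries into next_decoder_cache. A stripped-down sketch of that loop (layers is a hypothetical list of modules following the layer interface above, so this is an illustration rather than the library code):

    def run_stack(layers, hidden_states, past_key_values=None, use_cache=True):
        # Mirrors the loop above: one cache entry is consumed and one is produced per layer.
        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(layers):
            past_key_value = past_key_values[i] if past_key_values is not None else None
            layer_outputs = layer_module(hidden_states, past_key_value=past_key_value)
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
        return hidden_states, next_decoder_cache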
...@@ -345,11 +345,11 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -345,11 +345,11 @@ class EncoderDecoderModel(PreTrainedModel):
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, # TODO: (PVP) implement :obj:`use_cache` past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
labels=None, labels=None,
use_cache=None, # TODO: (PVP) implement :obj:`use_cache` use_cache=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
...@@ -413,18 +413,19 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -413,18 +413,19 @@ class EncoderDecoderModel(PreTrainedModel):
labels=labels, labels=labels,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict, return_dict=return_dict,
**kwargs_decoder, **kwargs_decoder,
) )
# TODO(PVP): currently it is not possible to use `past`
if not return_dict: if not return_dict:
return decoder_outputs + encoder_outputs return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput( return Seq2SeqLMOutput(
loss=decoder_outputs.loss, loss=decoder_outputs.loss,
logits=decoder_outputs.logits, logits=decoder_outputs.logits,
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions, cross_attentions=decoder_outputs.cross_attentions,
...@@ -433,24 +434,19 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -433,24 +434,19 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
def prepare_inputs_for_generation(
    self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
    decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
    decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
    input_dict = {
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
        "decoder_input_ids": decoder_inputs["input_ids"],
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,
        "use_cache": use_cache,
    }
    return input_dict
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past, beam_idx):
......
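With use_cache and past_key_values now threaded through the decoder, generation with an EncoderDecoderModel can reuse the cache directly. A usage sketch (checkpoint names and the prompt are illustrative, not taken from the diff):

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    inputs = tokenizer("Caching speeds up autoregressive decoding.", return_tensors="pt")
    # use_cache=True makes each decoding step feed only the new token plus past_key_values.
    generated = model.generate(
        inputs.input_ids, attention_mask=inputs.attention_mask, use_cache=True, max_length=20
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))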
...@@ -33,7 +33,7 @@ from ...file_utils import ( ...@@ -33,7 +33,7 @@ from ...file_utils import (
) )
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
) )
from ...modeling_utils import ( from ...modeling_utils import (
...@@ -851,7 +851,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -851,7 +851,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2", checkpoint="gpt2",
output_type=CausalLMOutputWithPastAndCrossAttentions, output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def forward( def forward(
...@@ -916,7 +916,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -916,7 +916,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPastAndCrossAttentions( return CausalLMOutputWithCrossAttentions(
loss=loss, loss=loss,
logits=lm_logits, logits=lm_logits,
past_key_values=transformer_outputs.past_key_values, past_key_values=transformer_outputs.past_key_values,
......
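For GPT-2 only the output class changes name: CausalLMOutputWithCrossAttentions already carries past_key_values, so the cache keeps flowing exactly as before. A quick, illustrative check:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    out = model(**tokenizer("Hello there", return_tensors="pt"), use_cache=True)
    # One cache entry per transformer block; cross_attentions stays None without an encoder.
    print(type(out).__name__, len(out.past_key_values))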
...@@ -24,7 +24,7 @@ from torch.nn import CrossEntropyLoss ...@@ -24,7 +24,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput, MaskedLMOutput,
TokenClassifierOutput, TokenClassifierOutput,
...@@ -151,6 +151,8 @@ class LayoutLMSelfAttention(nn.Module): ...@@ -151,6 +151,8 @@ class LayoutLMSelfAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x): def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape) x = x.view(*new_x_shape)
...@@ -163,6 +165,7 @@ class LayoutLMSelfAttention(nn.Module): ...@@ -163,6 +165,7 @@ class LayoutLMSelfAttention(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
...@@ -170,17 +173,37 @@ class LayoutLMSelfAttention(nn.Module): ...@@ -170,17 +173,37 @@ class LayoutLMSelfAttention(nn.Module):
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be # and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to. # such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
    # reuse k,v, cross_attentions
    key_layer = past_key_value[0]
    value_layer = past_key_value[1]
    attention_mask = encoder_attention_mask
elif is_cross_attention:
    key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
    value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
    attention_mask = encoder_attention_mask
elif past_key_value is not None:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
    value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
...@@ -224,6 +247,9 @@ class LayoutLMSelfAttention(nn.Module): ...@@ -224,6 +247,9 @@ class LayoutLMSelfAttention(nn.Module):
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
...@@ -275,6 +301,7 @@ class LayoutLMAttention(nn.Module): ...@@ -275,6 +301,7 @@ class LayoutLMAttention(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
self_outputs = self.self( self_outputs = self.self(
...@@ -283,6 +310,7 @@ class LayoutLMAttention(nn.Module): ...@@ -283,6 +310,7 @@ class LayoutLMAttention(nn.Module):
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions, output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
...@@ -343,36 +371,60 @@ class LayoutLMLayer(nn.Module): ...@@ -343,36 +371,60 @@ class LayoutLMLayer(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
assert hasattr( assert hasattr(
self, "crossattention" self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention( cross_attention_outputs = self.crossattention(
attention_output, attention_output,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
cross_attn_past_key_value,
output_attentions, output_attentions,
) )
attention_output = cross_attention_outputs[0] attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
...@@ -395,6 +447,8 @@ class LayoutLMEncoder(nn.Module): ...@@ -395,6 +447,8 @@ class LayoutLMEncoder(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
return_dict=True, return_dict=True,
...@@ -402,17 +456,19 @@ class LayoutLMEncoder(nn.Module): ...@@ -402,17 +456,19 @@ class LayoutLMEncoder(nn.Module):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
return module(*inputs, output_attentions) return module(*inputs, past_key_value, output_attentions)
return custom_forward return custom_forward
...@@ -431,9 +487,13 @@ class LayoutLMEncoder(nn.Module): ...@@ -431,9 +487,13 @@ class LayoutLMEncoder(nn.Module):
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions, output_attentions,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention: if self.config.add_cross_attention:
...@@ -445,11 +505,18 @@ class LayoutLMEncoder(nn.Module): ...@@ -445,11 +505,18 @@ class LayoutLMEncoder(nn.Module):
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None if v is not None
) )
return BaseModelOutputWithCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
......
...@@ -424,7 +424,6 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru ...@@ -424,7 +424,6 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
return attention_mask return attention_mask
# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx): def create_position_ids_from_input_ids(input_ids, padding_idx):
""" """
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
......
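The function keeps its RoBERTa semantics even though the `# Copied from` marker is dropped: padding positions keep padding_idx and real tokens are numbered from padding_idx + 1. A small self-contained re-implementation of that behaviour (input values are illustrative):

    import torch

    def create_position_ids_from_input_ids(input_ids, padding_idx):
        # Non-padding tokens get positions padding_idx + 1, padding_idx + 2, ...;
        # padding tokens keep position padding_idx.
        mask = input_ids.ne(padding_idx).int()
        incremental_indices = torch.cumsum(mask, dim=1) * mask
        return incremental_indices.long() + padding_idx

    input_ids = torch.tensor([[0, 5, 6, 1, 1]])  # assume padding_idx == 1
    print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
    # tensor([[2, 3, 4, 1, 1]])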
...@@ -347,6 +347,7 @@ class TapasSelfAttention(nn.Module): ...@@ -347,6 +347,7 @@ class TapasSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x): def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
...@@ -360,6 +361,7 @@ class TapasSelfAttention(nn.Module): ...@@ -360,6 +361,7 @@ class TapasSelfAttention(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
...@@ -367,17 +369,30 @@ class TapasSelfAttention(nn.Module): ...@@ -367,17 +369,30 @@ class TapasSelfAttention(nn.Module):
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be # and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to. # such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
    # reuse k,v, cross_attentions
    key_layer = past_key_value[0]
    value_layer = past_key_value[1]
    attention_mask = encoder_attention_mask
elif is_cross_attention:
    key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
    value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
    attention_mask = encoder_attention_mask
elif past_key_value is not None:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
    value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

if self.is_decoder:
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
...@@ -404,6 +419,8 @@ class TapasSelfAttention(nn.Module): ...@@ -404,6 +419,8 @@ class TapasSelfAttention(nn.Module):
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
...@@ -455,6 +472,7 @@ class TapasAttention(nn.Module): ...@@ -455,6 +472,7 @@ class TapasAttention(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
self_outputs = self.self( self_outputs = self.self(
...@@ -463,6 +481,7 @@ class TapasAttention(nn.Module): ...@@ -463,6 +481,7 @@ class TapasAttention(nn.Module):
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions, output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
...@@ -523,36 +542,60 @@ class TapasLayer(nn.Module): ...@@ -523,36 +542,60 @@ class TapasLayer(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
assert hasattr( assert hasattr(
self, "crossattention" self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention( cross_attention_outputs = self.crossattention(
attention_output, attention_output,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
cross_attn_past_key_value,
output_attentions, output_attentions,
) )
attention_output = cross_attention_outputs[0] attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
...@@ -574,6 +617,8 @@ class TapasEncoder(nn.Module): ...@@ -574,6 +617,8 @@ class TapasEncoder(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
return_dict=True, return_dict=True,
...@@ -590,7 +635,7 @@ class TapasEncoder(nn.Module): ...@@ -590,7 +635,7 @@ class TapasEncoder(nn.Module):
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
return module(*inputs, output_attentions) return module(*inputs, past_key_values, output_attentions)
return custom_forward return custom_forward
...@@ -609,6 +654,7 @@ class TapasEncoder(nn.Module): ...@@ -609,6 +654,7 @@ class TapasEncoder(nn.Module):
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_values,
output_attentions, output_attentions,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
......
...@@ -150,7 +150,7 @@ class BartModelTester: ...@@ -150,7 +150,7 @@ class BartModelTester:
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict["input_ids"]
# first forward pass # first forward pass
outputs = model(input_ids, use_cache=True) outputs = model(input_ids, attention_mask=inputs_dict["attention_mask"], use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
......
...@@ -260,6 +260,66 @@ class BertModelTester: ...@@ -260,6 +260,66 @@ class BertModelTester:
) )
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = BertLMHeadModel(config=config).to(torch_device).eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create a hypothetical batch of next tokens and extend next_input_ids with them
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append the new tokens to input_ids and the new mask to the attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
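The test above checks that feeding only the new tokens plus the cache yields the same hidden states as re-running the full sequence. The same pattern in plain usage terms, with a tiny randomly initialized decoder (all config values are illustrative):

    import torch
    from transformers import BertConfig, BertLMHeadModel

    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2,
        num_attention_heads=2, intermediate_size=64, is_decoder=True,
    )
    model = BertLMHeadModel(config).eval()

    input_ids = torch.randint(1, 100, (1, 8))
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
        next_token = out.logits[:, -1:].argmax(-1)
        # Second step: pass only the new token together with the cached key/values.
        out = model(next_token, past_key_values=out.past_key_values, use_cache=True)
    print(out.logits.shape)  # torch.Size([1, 1, 100]) - logits for the new position only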
def create_and_check_for_next_sequence_prediction( def create_and_check_for_next_sequence_prediction(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
): ):
...@@ -454,6 +514,10 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -454,6 +514,10 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs) self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_multiple_choice(self): def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
......
...@@ -25,6 +25,8 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r ...@@ -25,6 +25,8 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
if is_torch_available(): if is_torch_available():
import torch
from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder
...@@ -156,6 +158,64 @@ class BertGenerationEncoderTester: ...@@ -156,6 +158,64 @@ class BertGenerationEncoderTester:
) )
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
input_mask,
token_labels,
encoder_hidden_states,
encoder_attention_mask,
**kwargs,
):
config.is_decoder = True
config.add_cross_attention = True
model = BertGenerationDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create a hypothetical batch of next tokens and extend next_input_ids with them
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append the new tokens to input_ids and the new mask to the attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_for_causal_lm( def create_and_check_for_causal_lm(
self, self,
config, config,
...@@ -203,6 +263,10 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes ...@@ -203,6 +263,10 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_model_as_decoder_with_default_input_mask(self): def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3 # This regression test was failing with PyTorch < 1.3
( (
......
...@@ -198,6 +198,74 @@ class RobertaModelTester: ...@@ -198,6 +198,74 @@ class RobertaModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = RobertaForCausalLM(config=config).to(torch_device).eval()
# make sure the random ids don't contain the pad token (RoBERTa derives position ids from the padding mask)
mask = input_ids.ne(config.pad_token_id).long()
input_ids = input_ids * mask
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create a hypothetical batch of next tokens and extend next_input_ids with them
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
# make sure the random next tokens don't contain the pad token either
mask = next_tokens.ne(config.pad_token_id).long()
next_tokens = next_tokens * mask
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append the new tokens to input_ids and the new mask to the attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_for_masked_lm( def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
): ):
...@@ -337,6 +405,10 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas ...@@ -337,6 +405,10 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_masked_lm(self): def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
......