Unverified Commit 9d37c56b authored by Patrick von Platen, committed by GitHub

[Reformer] - Cache hidden states and buckets to speed up inference (#5578)

* fix merge rebase

* add intermediate reformer code

* save intermediate caching results

* save intermediate

* save intermediate results

* save intermediate

* upload next step

* fix generate tests

* make tests work

* add named tuple output

* Apply suggestions from code review

* fix use_cache for False case

* fix tensor to gpu

* fix tensor to gpu

* refactor

* refactor and make style
parent 0b6c255a
@@ -18,8 +18,10 @@
import logging
import sys
from collections import namedtuple
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import List, Optional, Tuple
import numpy as np
import torch
@@ -32,6 +34,7 @@ from .configuration_reformer import ReformerConfig
from .file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
@@ -80,7 +83,18 @@ ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "
ReformerBackwardOutput = namedtuple(
"ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]
)
ReformerEncoderOutput = namedtuple(
"ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
)
def _stable_argsort(vector, dim):
# this function scales the vector so that torch.argsort is stable.
# torch.argsort is not stable on its own
scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
scale_offset = scale_offset.expand(vector.shape)
scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
return torch.argsort(scaled_vector, dim=dim)
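Note on the helper above: multiplying each bucket id by the sequence length and adding the position offset makes ties between equal ids resolve in their original order, which is what makes the sort stable. A minimal, self-contained sketch (hypothetical 1x1x6 bucket tensor, helper re-declared so it runs on its own):

import torch

def _stable_argsort(vector, dim):
    # scale the vector so that torch.argsort becomes stable
    scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
    scale_offset = scale_offset.expand(vector.shape)
    scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
    return torch.argsort(scaled_vector, dim=dim)

# two hash buckets; ties must keep their original left-to-right order
buckets = torch.tensor([[[1, 0, 1, 0, 1, 0]]])
print(_stable_argsort(buckets, dim=-1))  # tensor([[[1, 3, 5, 0, 2, 4]]])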
def _get_least_common_mult_chunk_len(config):
@@ -100,6 +114,23 @@ def _get_least_common_mult_chunk_len(config):
)
def _get_min_chunk_len(config):
attn_types = config.attn_layers
attn_types_set = set(attn_types)
if len(attn_types_set) == 1 and attn_types[0] == "lsh":
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
return config.local_attn_chunk_length
elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
raise NotImplementedError(
"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(
config.attn_layers
)
)
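The two helpers complement each other: full forward passes are padded up to the least common multiple of the chunk lengths, while the new minimum-chunk-length helper decides whether padding is needed at all. A small sketch of both quantities under assumed chunk lengths of 64 (LSH) and 128 (local), using plain Python values instead of a ReformerConfig:

from functools import reduce
from math import gcd

# hypothetical values; the real ones come from ReformerConfig
lsh_attn_chunk_length, local_attn_chunk_length = 64, 128
attn_layers = ["local", "lsh", "local", "lsh"]

def lcm(a, b):
    # least common multiple via the greatest common divisor
    return a * b // gcd(a, b)

chunk_lengths = {"lsh": lsh_attn_chunk_length, "local": local_attn_chunk_length}
used = [chunk_lengths[t] for t in set(attn_layers)]
least_common_mult_chunk_len = reduce(lcm, used)
min_chunk_len = min(used)
print(least_common_mult_chunk_len, min_chunk_len)  # 128 64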
class AxialPositionEmbeddings(nn.Module):
"""Constructs axial position embeddings. Useful for very long input
sequences to save memory and time.
@@ -171,15 +202,23 @@ class AxialPositionEmbeddings(nn.Module):
)
# compute how many columns are needed
max_position_id = position_ids.max().item()
required_pos_encodings_columns = -(-(max_position_id + 1) // self.axial_pos_shape[1])
# cut to columns that are needed
position_encodings = torch.cat(
[weight[:, :required_pos_encodings_columns] for weight in broadcasted_weights], dim=-1
)
position_encodings = torch.reshape(position_encodings, (batch_size, -1, position_encodings.shape[-1]))
# select correct position encodings
position_encodings = torch.cat(
[
torch.index_select(position_encodings[i], 0, position_ids[i]).unsqueeze(0)
for i in range(batch_size)
],
dim=0,
)
return position_encodings
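Selecting the axial encodings through position_ids instead of slicing the first sequence_length rows is what lets a cached, single-token step pick up the encoding at an arbitrary offset. A minimal sketch of the per-batch index_select with made-up shapes (not the module itself):

import torch

batch_size, hidden_size, max_len = 2, 8, 16
# stand-in for the broadcast axial weights, already flattened to (batch, max_len, hidden)
position_encodings = torch.randn(batch_size, max_len, hidden_size)
# during cached decoding each batch element asks for a single, possibly different, position
position_ids = torch.tensor([[5], [9]])

selected = torch.cat(
    [torch.index_select(position_encodings[i], 0, position_ids[i]).unsqueeze(0) for i in range(batch_size)],
    dim=0,
)
print(selected.shape)  # torch.Size([2, 1, 8])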
@@ -213,7 +252,7 @@ class ReformerEmbeddings(nn.Module):
AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config)
)
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, start_idx_pos_encodings=0):
if input_ids is not None:
input_shape = input_ids.size()
device = input_ids.device
@@ -223,7 +262,9 @@
seq_length = input_shape[1]
if position_ids is None:
position_ids = torch.arange(
start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if inputs_embeds is None:
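With start_idx_pos_encodings the embedding layer continues the position ids where the cache stopped instead of restarting at zero. A tiny sketch of that arithmetic, assuming 12 cached tokens and one new token:

import torch

start_idx_pos_encodings = 12   # length of the cached hidden states
seq_length = 1                 # one new token per decoding step
input_shape = (2, seq_length)  # hypothetical batch of 2

position_ids = torch.arange(
    start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long
)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
print(position_ids)  # tensor([[12], [12]])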
@@ -339,8 +380,10 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
attention_mask=None,
head_mask=None,
num_hashes=None,
buckets=None,
past_buckets_states=None,
use_cache=False,
output_attentions=False,
**kwargs
):
sequence_length = hidden_states.shape[1]
@@ -349,17 +392,72 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# num hashes can optionally be overwritten by user
num_hashes = num_hashes if num_hashes is not None else self.num_hashes
do_cached_attention = use_cache and past_buckets_states[1] is not None
# check if cache shall be used and that hidden states are already cached
if do_cached_attention:
assert (
sequence_length == 1
), f"At the moment, auto-regressive language generation is only possible one word at a time. Make sure that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed."
past_buckets = past_buckets_states[0]
past_states = past_buckets_states[1]
# get query vector
query_vectors = self.query_key(hidden_states)
query_vectors = self._split_hidden_size_dim(
query_vectors, self.num_attention_heads, self.attention_head_size
)
if past_buckets is not None:
key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets(
query_vectors=query_vectors,
attention_mask=attention_mask,
num_hashes=num_hashes,
hidden_states=hidden_states,
past_states=past_states,
past_buckets=past_buckets,
)
query_key_vectors = self._query_per_attn_head(key_value_hidden_states)
value_vectors = self._value_per_attn_head(key_value_hidden_states)
# split key & value vectors by num hashes to apply
# self attention on each separately
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size,
)
# repeat query vectors across hash dimension
query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1)
else:
key_value_hidden_states = torch.cat([past_states, hidden_states], dim=1)
query_key_vectors = self.query_key(key_value_hidden_states)
value_vectors = self.value(key_value_hidden_states)
else:
# project hidden_states to query_key and value
query_vectors = None
query_key_vectors = self.query_key(hidden_states)
value_vectors = self.value(hidden_states)
# if query key is not already split
if not do_cached_attention or past_buckets is None:
query_key_vectors = self._split_hidden_size_dim(
query_key_vectors, self.num_attention_heads, self.attention_head_size
)
value_vectors = self._split_hidden_size_dim(
value_vectors, self.num_attention_heads, self.attention_head_size
)
# cache buckets for next incremental decoding
if do_cached_attention and past_buckets is None and key_value_hidden_states.shape[1] >= self.chunk_length:
buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask)
# free memory
del hidden_states
assert (
query_key_vectors.shape[-1] == self.attention_head_size
@@ -372,8 +470,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
value_vectors.shape[-1], self.attention_head_size
)
do_standard_self_attention = (sequence_length <= self.chunk_length) or (
use_cache and past_buckets_states[1] is not None
)
# LSH attention only makes sense if chunked attention should be performed
if not do_standard_self_attention:
# set `num_buckets` on the fly, recommended way to do it
if self.num_buckets is None:
self._set_num_buckets(sequence_length)
@@ -382,6 +483,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
if buckets is None:
# hash query key vectors into buckets
buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask)
else:
# make sure buckets has correct shape for LSH attention
buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes * sequence_length)
assert (
int(buckets.shape[-1]) == num_hashes * sequence_length
@@ -397,7 +501,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# cluster query key value vectors according to hashed buckets
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
@@ -409,6 +512,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
assert (
self.num_chunks_before == 0 and self.num_chunks_after == 0
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
elif do_cached_attention and past_buckets is not None:
# use max sequence length
sorted_bucket_idx_per_hash = sorted_bucket_idx
else:
# get sequence length indices
sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
@@ -418,25 +524,33 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# scale key vectors
key_vectors = self._len_and_dim_norm(query_key_vectors)
# set query_vectors to query key vectors if LSH self attention
query_vectors = query_vectors if query_vectors is not None else query_key_vectors
# free memory
del query_key_vectors
# get attention probs
out_vectors, logits, attention_probs = self._attend(
query_vectors=query_vectors,
key_vectors=key_vectors,
value_vectors=value_vectors,
sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash,
attention_mask=attention_mask,
head_mask=head_mask,
do_standard_self_attention=do_standard_self_attention,
do_cached_attention=do_cached_attention,
)
# free memory
del key_vectors, value_vectors
# re-order out_vectors and logits
if not do_standard_self_attention:
# sort clusters back to correct ordering
out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx)
if not do_standard_self_attention or (do_cached_attention and past_buckets is not None):
# sum up all hash rounds
if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to(
@@ -466,9 +580,28 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
if output_attentions is False:
attention_probs = ()
if buckets is not None:
buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes, -1)
return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets)
def _query_per_attn_head(self, hidden_states):
per_head_query_key = self.query_key.weight.reshape(
self.num_attention_heads, self.attention_head_size, self.hidden_size
).transpose(-2, -1)
# only relevant for inference and no bias => we can use einsum here
query_key_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_query_key)
return query_key_vectors
def _value_per_attn_head(self, hidden_states):
per_head_value = self.value.weight.reshape(
self.num_attention_heads, self.attention_head_size, self.hidden_size
).transpose(-2, -1)
# only relevant for inference and no bias => we can use einsum here
value_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_value)
return value_vectors
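Because query_key and value are bias-free linear layers, their weight can be reshaped into one (head, hidden, head_size) matrix per attention head and applied with einsum to hidden states that are already laid out per head; the result equals running the full projection and keeping each head's own slice. A standalone sketch with hypothetical sizes:

import torch
from torch import nn

# hypothetical sizes; in the module, hidden_size == num_attention_heads * attention_head_size
num_heads, head_size, hidden_size = 2, 4, 8
query_key = nn.Linear(hidden_size, num_heads * head_size, bias=False)

# (batch, heads, seq_len, hidden) -- the cached key/value hidden states are already per head
hidden_states = torch.randn(1, num_heads, 3, hidden_size)

per_head_weight = query_key.weight.reshape(num_heads, head_size, hidden_size).transpose(-2, -1)
per_head_out = torch.einsum("balh,ahr->balr", hidden_states, per_head_weight)

# same numbers as running the full projection and keeping each head's own slice
full = query_key(hidden_states).reshape(1, num_heads, 3, num_heads, head_size)
reference = torch.stack([full[:, h, :, h] for h in range(num_heads)], dim=1)
assert torch.allclose(per_head_out, reference, atol=1e-6)
print(per_head_out.shape)  # torch.Size([1, 2, 3, 4])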
def _hash_vectors(self, vectors, num_hashes, attention_mask, increase_num_buckets=False):
batch_size = vectors.shape[0]
# See https://arxiv.org/pdf/1509.02897.pdf
@@ -514,7 +647,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)]
cur_sum = cur_sum + bucket_factor // 2
rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1)
if buckets is None:
buckets = torch.argmax(rotated_vectors_factor, dim=-1)
else:
@@ -522,7 +654,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
cur_product = cur_product * bucket_factor
if attention_mask is not None and (attention_mask.sum().item() < batch_size * attention_mask.shape[-1]):
# add an extra bucket for padding tokens only
num_buckets = num_buckets + 1
# assign padding tokens extra bucket
@@ -530,6 +662,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
buckets = torch.where(
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
)
elif increase_num_buckets:
num_buckets = num_buckets + 1
# buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len).
# Next we add offsets so that bucket numbers from different hashing rounds don't overlap.
@@ -545,20 +679,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes):
# no gradients are needed
with torch.no_grad():
# hash-based sort
sorted_bucket_idx = _stable_argsort(buckets, dim=-1)
# create simple indices to scatter to, to have undo sort
indices = (
@@ -600,26 +722,37 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
sorted_bucket_idx_per_hash,
attention_mask,
head_mask,
do_standard_self_attention,
do_cached_attention,
):
# look at previous and following chunks if chunked attention
if not do_standard_self_attention:
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
# get logits and dots
# (BS, NumAttn, NumHash x NumChunk, Chunk_L x Hidden),(BS, NumAttn, NumHash x NumChunk, Chunk_L * (Num_bef + Num_aft + 1) x Hidden) -> (BS, NumAttn, NumHash x NumChunk, Chunk_L, Chunk_L * (1 + Num_bef + Num_aft))
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
# free memory
del query_vectors, key_vectors
# if chunked attention split bucket idxs to query and key
if not do_standard_self_attention:
query_bucket_idx = self._split_seq_length_dim_to(
sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads
)
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
elif do_cached_attention and query_key_dots.ndim > 4:
key_value_bucket_idx = sorted_bucket_idx_per_hash
query_bucket_idx = (
key_value_bucket_idx.new_ones(key_value_bucket_idx.shape[:-1] + (1,)) * key_value_bucket_idx.max()
)
elif do_cached_attention and query_key_dots.ndim <= 4:
query_bucket_idx = (query_key_dots.shape[-1] - 1) * torch.ones_like(query_key_dots)[:, :, :, -1]
key_value_bucket_idx = torch.arange(
query_key_dots.shape[-1], dtype=torch.long, device=query_key_dots.device
)[None, None, :].expand(query_bucket_idx.shape[:2] + (-1,))
else:
query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash
@@ -631,8 +764,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
self_mask_value = self.self_mask_value_float32
mask_value = self.mask_value_float32
if not do_cached_attention:
mask = self._compute_attn_mask(
query_bucket_idx,
key_value_bucket_idx,
attention_mask,
query_key_dots.shape,
do_standard_self_attention,
)
if mask is not None:
@@ -682,19 +820,20 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
del value_vectors
# merge chunk length
if out_vectors.ndim > 4:
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
return out_vectors, logits, attention_probs
def _compute_attn_mask(
self, query_indices, key_indices, attention_mask, query_key_dot_shape, do_standard_self_attention
):
# attention mask for LSH
if attention_mask is not None:
# if chunked attention, the attention mask has to correspond to LSH order
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
if not do_standard_self_attention:
# expand attn_mask to fit with key_value_bucket_idx shape
attention_mask = attention_mask[:, None, :]
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
@@ -715,6 +854,102 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
return attention_mask
def _get_relevant_hid_states_and_buckets(
self, query_vectors, attention_mask, num_hashes, hidden_states, past_states, past_buckets
):
# concat hidden states
hidden_states = torch.cat([past_states, hidden_states], dim=1)
# batch_size hidden
batch_size = hidden_states.shape[0]
sequence_length = hidden_states.shape[1]
# check if cached buckets include pad bucket
max_bucket = self.num_buckets if isinstance(self.num_buckets, int) else reduce(mul, self.num_buckets)
# if pad bucket was cached => need to increase num buckets for caching
increase_num_buckets = past_buckets.max() > num_hashes * max_bucket - 1
# retrieve query buckets
query_buckets = self._hash_vectors(
query_vectors, num_hashes, attention_mask, increase_num_buckets=increase_num_buckets
)
# concat buckets
concat_buckets = torch.cat([past_buckets, query_buckets.unsqueeze(-1)], dim=-1)
# hash-based sort
bucket_idx = _stable_argsort(concat_buckets, dim=-1)
# bucket_idx has shape: BatchSize x NumAttnHeads x NumHashes x SequenceLength
assert bucket_idx.shape == (
batch_size,
self.num_attention_heads,
num_hashes,
sequence_length,
), f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but has shape {bucket_idx.shape}."
# find indices of new bucket indices
relevant_bucket_idx = (bucket_idx == (bucket_idx.shape[-1] - 1)).nonzero()
# expand relevant bucket indices to its chunks
relevant_bucket_idx_chunk = self._expand_to_indices_in_relevant_chunk(relevant_bucket_idx, sequence_length)
relevant_bucket_idx_chunk = bucket_idx[tuple(relevant_bucket_idx_chunk.transpose(0, 1))]
# adapt bucket_idx for batch and hidden states for index select
bucket_idx_batch_offset = sequence_length * (
batch_size
* torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long)
// relevant_bucket_idx_chunk.shape[-1]
)
# add batch offset
relevant_bucket_idx_chunk_all_batch = relevant_bucket_idx_chunk + bucket_idx_batch_offset
hidden_states = hidden_states.reshape((-1, self.hidden_size))
# select all relevant hidden states
relevant_hidden_states = hidden_states.index_select(0, relevant_bucket_idx_chunk_all_batch)
# reshape hidden states and bucket_idx to correct output
relevant_hidden_states = relevant_hidden_states.reshape(
batch_size, self.num_attention_heads, -1, self.hidden_size
)
relevant_bucket_idx_chunk = relevant_bucket_idx_chunk.reshape(
batch_size, self.num_attention_heads, num_hashes, -1
)
assert (
relevant_hidden_states.shape[2]
== (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes
), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`, there are {relevant_hidden_states.shape[2]} `hidden_states`."
assert (
relevant_bucket_idx_chunk.shape[-1]
== (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length
), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`."
return relevant_hidden_states, relevant_bucket_idx_chunk, query_buckets
def _expand_to_indices_in_relevant_chunk(self, indices, sequence_length):
# get relevant indices of where chunk starts and its size
start_indices_chunk = ((indices[:, -1] // self.chunk_length) - self.num_chunks_before) * self.chunk_length
total_chunk_size = self.chunk_length * (1 + self.num_chunks_before + self.num_chunks_after)
# expand start indices and add correct chunk offset via arange
expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size)
chunk_sequence_indices = expanded_start_indices + torch.arange(
total_chunk_size, device=indices.device, dtype=torch.long
).unsqueeze(0).expand(indices.shape[0], total_chunk_size)
# make sure that circular logic holds via % seq len
chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length
# expand indices and set indices correctly
indices = indices.unsqueeze(1).expand((indices.shape[0], total_chunk_size, -1)).flatten(0, 1).clone()
indices[:, -1] = chunk_sequence_indices
return indices
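The chunk-expansion helper turns the position of the freshly added token (in sorted bucket order) into the indices of its whole chunk plus the preceding num_chunks_before chunks, wrapping around the sequence. A standalone sketch of just that position arithmetic under hypothetical settings (the real helper additionally rewrites the last column of the expanded index tensor):

import torch

chunk_length, num_chunks_before, num_chunks_after = 4, 1, 0
sequence_length = 16

# one hit: (batch, head, hash, position-in-sorted-order) of the freshly added token
indices = torch.tensor([[0, 0, 0, 9]])

start_indices_chunk = ((indices[:, -1] // chunk_length) - num_chunks_before) * chunk_length
total_chunk_size = chunk_length * (1 + num_chunks_before + num_chunks_after)

expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size)
chunk_sequence_indices = expanded_start_indices + torch.arange(total_chunk_size).unsqueeze(0).expand(
    indices.shape[0], total_chunk_size
)
# circular logic via % sequence_length
chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length
print(chunk_sequence_indices.tolist())  # [4, 5, 6, 7, 8, 9, 10, 11]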
def _len_and_dim_norm(self, vectors):
"""
length and attention head size dim normalization
@@ -803,10 +1038,38 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
self.register_buffer("mask_value_float16", torch.tensor(-1e4))
self.register_buffer("mask_value_float32", torch.tensor(-1e9))
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
past_buckets_states=None,
use_cache=False,
output_attentions=False,
**kwargs
):
sequence_length = hidden_states.shape[1]
batch_size = hidden_states.shape[0]
# check if cache shall be used and that hidden states are already cached
if use_cache and past_buckets_states[1] is not None:
assert (
past_buckets_states[0] is None
), "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching hidden_states_and_buckets."
key_value_hidden_states = self._retrieve_relevant_hidden_states(
past_buckets_states[1], self.chunk_length, self.num_chunks_before
)
key_value_hidden_states = torch.cat([key_value_hidden_states, hidden_states], dim=1)
# only query vector for last token
query_vectors = self.query(hidden_states)
# compute key and value for relevant chunk
key_vectors = self.key(key_value_hidden_states)
value_vectors = self.value(key_value_hidden_states)
# free memory
del key_value_hidden_states
else:
# project hidden_states to query, key and value
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
@@ -848,8 +1111,11 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
batch_size, self.num_attention_heads, 1
)
# if one should do normal n^2 self-attention
do_standard_self_attention = sequence_length <= self.chunk_length
# if input should be chunked
if not do_standard_self_attention:
# chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
@@ -880,7 +1146,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
del query_vectors, key_vectors
mask = self._compute_attn_mask(
query_indices, key_indices, attention_mask, query_key_dots.shape, do_standard_self_attention
)
if mask is not None:
@@ -916,7 +1182,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
del value_vectors
# merge chunk length
if not do_standard_self_attention:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
@@ -928,13 +1194,15 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
def _compute_attn_mask(
self, query_indices, key_indices, attention_mask, query_key_dots_shape, do_standard_self_attention
):
# chunk attention mask and look before and after
if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
if not do_standard_self_attention:
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
# create attn_mask
@@ -952,6 +1220,11 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
return attention_mask
@staticmethod
def _retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before):
start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * chunk_length
return previous_hidden_states[:, start_position:]
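For local attention the cache lookup is only a slice: keep the trailing chunks that a single new query can still attend to. A quick standalone check with hypothetical sizes, mirroring the static helper above:

import torch

def retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before):
    # mirror of the static helper above
    start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * chunk_length
    return previous_hidden_states[:, start_position:]

cached = torch.randn(1, 133, 8)  # 133 cached positions, hidden size 8
kept = retrieve_relevant_hidden_states(cached, chunk_length=64, num_chunks_before=1)
print(kept.shape)                # torch.Size([1, 69, 8]) -> last full chunk plus the 5 overflow positions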
class ReformerSelfOutput(nn.Module):
def __init__(self, config):
@@ -999,21 +1272,31 @@ class ReformerAttention(nn.Module):
attention_mask=None,
head_mask=None,
num_hashes=None,
past_buckets_states=None,
use_cache=False,
orig_sequence_length=None,
output_attentions=False,
buckets=None,
):
hidden_states = self.layer_norm(hidden_states)
# make sure cached hidden states is set to None for backward pass
if past_buckets_states is not None:
past_buckets_states_layer = past_buckets_states[self.layer_id]
else:
past_buckets_states_layer = None
# use cached buckets for backprop if buckets not None for LSHSelfAttention
self_attention_outputs = self.self_attention(
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states_layer,
use_cache=use_cache,
output_attentions=output_attentions,
buckets=buckets,
)
# add buckets if necessary
if hasattr(self_attention_outputs, "buckets"):
@@ -1021,6 +1304,28 @@ class ReformerAttention(nn.Module):
else:
buckets = None
# cache hidden states for future use
if use_cache:
if past_buckets_states[self.layer_id][0] is None:
# padded input should not be cached
past_buckets = (
buckets[:, :, :, :orig_sequence_length]
if (buckets is not None and orig_sequence_length > 1)
else buckets
)
else:
past_buckets = torch.cat([past_buckets_states[self.layer_id][0], buckets], dim=-1)
if past_buckets_states[self.layer_id][1] is None:
# padded input should not be cached
past_states = hidden_states[:, :orig_sequence_length]
else:
past_states = torch.cat([past_buckets_states[self.layer_id][1], hidden_states], dim=1)
past_buckets_states[self.layer_id] = (past_buckets, past_states)
# compute attention feed forward output
attention_output = self.output(self_attention_outputs.hidden_states)
return AttentionOutput(
hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets,
)
@@ -1137,6 +1442,9 @@ class ReformerLayer(nn.Module):
attention_mask=None,
head_mask=None,
num_hashes=None,
past_buckets_states=None,
use_cache=False,
orig_sequence_length=None,
output_attentions=False,
):
with torch.no_grad():
@@ -1149,6 +1457,9 @@ class ReformerLayer(nn.Module):
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
orig_sequence_length=orig_sequence_length,
output_attentions=output_attentions,
)
attn_output = attn_outputs.hidden_states
@@ -1254,6 +1565,9 @@ class _ReversibleFunction(Function):
num_hashes,
all_hidden_states,
all_attentions,
past_buckets_states,
use_cache,
orig_sequence_length,
output_hidden_states,
output_attentions,
):
@@ -1262,7 +1576,7 @@ class _ReversibleFunction(Function):
# split duplicated tensor
hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)
for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
if output_hidden_states is True:
all_hidden_states.append(hidden_states)
@@ -1272,8 +1586,12 @@ class _ReversibleFunction(Function):
attention_mask=attention_mask,
head_mask=layer_head_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
orig_sequence_length=orig_sequence_length,
output_attentions=output_attentions,
)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
all_buckets = all_buckets + (layer_outputs.buckets,)
@@ -1339,7 +1657,7 @@ class _ReversibleFunction(Function):
# num of return vars has to match num of forward() args
# return gradient for hidden_states arg and None for other args
return grad_hidden_states, None, None, None, None, None, None, None, None, None, None, None
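Each layer's cache entry is a (buckets, hidden_states) tuple; ReformerAttention extends the buckets along the last axis and the hidden states along the sequence axis on every decoding step, and local-attention layers keep buckets as None. A small sketch of how such a list evolves, independent of the model classes (the real update happens in place inside ReformerAttention.forward):

import torch

num_layers, batch_size, num_heads, num_hashes, hidden_size = 2, 1, 2, 4, 8

# same initialisation as ReformerEncoder: nothing cached yet
past_buckets_states = [(None, None) for _ in range(num_layers)]

def append_to_cache(layer_id, new_buckets, new_hidden_states):
    past_buckets, past_states = past_buckets_states[layer_id]
    if new_buckets is not None:
        past_buckets = new_buckets if past_buckets is None else torch.cat([past_buckets, new_buckets], dim=-1)
    past_states = new_hidden_states if past_states is None else torch.cat([past_states, new_hidden_states], dim=1)
    past_buckets_states[layer_id] = (past_buckets, past_states)

# two decoding steps: layer 0 behaves like an LSH layer, layer 1 like a local layer
for step in range(2):
    append_to_cache(
        0,
        torch.zeros(batch_size, num_heads, num_hashes, 1, dtype=torch.long),
        torch.randn(batch_size, 1, hidden_size),
    )
    append_to_cache(1, None, torch.randn(batch_size, 1, hidden_size))

print(past_buckets_states[0][0].shape)  # torch.Size([1, 2, 4, 2])
print(past_buckets_states[1][1].shape)  # torch.Size([1, 2, 8])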
class ReformerEncoder(nn.Module):
@@ -1358,6 +1676,9 @@ class ReformerEncoder(nn.Module):
attention_mask=None,
head_mask=None,
num_hashes=None,
past_buckets_states=None,
use_cache=False,
orig_sequence_length=None,
output_hidden_states=False,
output_attentions=False,
):
@@ -1365,6 +1686,10 @@ class ReformerEncoder(nn.Module):
all_hidden_states = []
all_attentions = []
# init cached hidden states if necessary
if past_buckets_states is None:
past_buckets_states = [((None), (None)) for i in range(len(self.layers))]
# concat same tensor for reversible ResNet
hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
hidden_states = _ReversibleFunction.apply(
@@ -1375,6 +1700,9 @@ class ReformerEncoder(nn.Module):
num_hashes,
all_hidden_states,
all_attentions,
past_buckets_states,
use_cache,
orig_sequence_length,
output_hidden_states,
output_attentions,
)
@@ -1386,7 +1714,10 @@ class ReformerEncoder(nn.Module):
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return ReformerEncoderOutput(
hidden_states=hidden_states,
all_hidden_states=all_hidden_states,
all_attentions=all_attentions,
past_buckets_states=past_buckets_states,
)
@@ -1448,6 +1779,85 @@ class ReformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@dataclass
class ReformerModelOutput(ModelOutput):
"""
Output type of :class:`~transformers.ReformerModel`.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tuple(torch.LongTensor, torch.FloatTensor` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape
:obj:`(batch_size, num_heads, num_hashes, sequence_length)`)
and :obj:`tuple(1)` being the previous `hidden_states` of shape
:obj:`(batch_size, sequence_length, hidden_size)`).
Contains pre-computed buckets and hidden-states that can be used (see
``past_buckets_states`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor
past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ReformerModelWithLMHeadOutput(ModelOutput):
"""
Output type of :class:`~transformers.ReformerModelWithLMHead`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss (for next-token prediction).
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tuple(torch.LongTensor, torch.FloatTensor` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape
:obj:`(batch_size, num_heads, num_hashes, sequence_length)`)
and :obj:`tuple(1)` being the previous `hidden_states` of shape
:obj:`(batch_size, sequence_length, hidden_size)`).
Contains pre-computed buckets and hidden-states that can be used (see
``past_buckets_states`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
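When return_tuple is not set, ReformerModel now returns this dataclass, so the cache comes back as an attribute. A short usage sketch with a hand-built instance (assumes the ReformerModelOutput dataclass above is in scope; shapes are made up):

import torch

# a hand-built instance, just to show how the fields are accessed
outputs = ReformerModelOutput(
    last_hidden_state=torch.randn(1, 1, 256),
    past_buckets_states=[(None, torch.randn(1, 1, 256))],
)

sequence_output = outputs.last_hidden_state
cache = outputs.past_buckets_states  # list with one (buckets, hidden_states) tuple per layer
print(cache[0][1].shape)             # torch.Size([1, 1, 256])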
REFORMER_START_DOCSTRING = r"""
Reformer was proposed in `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`__
by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
@@ -1499,6 +1909,15 @@ REFORMER_INPUTS_DOCSTRING = r"""
bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined
in `config.num_hashes`.
For more information, see `num_hashes` in :class:`transformers.ReformerConfig`.
past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, defaults to :obj:`None`):
List of :obj:`tuple(torch.LongTensor, torch.FloatTensor)` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape
:obj:`(batch_size, num_heads, num_hashes, sequence_length)`
and :obj:`tuple(1)` being the previous `hidden_states` of shape
:obj:`(batch_size, sequence_length, hidden_size)`.
List of tuples that contains all previously computed hidden states and buckets (only relevant for LSH Self-Attention). Can be used to speed up sequential decoding.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the ``past_buckets_states`` of all attention layers are returned.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
@@ -1554,10 +1973,39 @@ class ReformerModel(ReformerPreTrainedModel):
head_mask=None,
inputs_embeds=None,
num_hashes=None,
past_buckets_states=None,
use_cache=None,
output_hidden_states=None,
output_attentions=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tuple(torch.LongTensor, torch.FloatTensor` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape
:obj:`(batch_size, num_heads, num_hashes, sequence_length)`)
and :obj:`tuple(1)` being the previous `hidden_states` of shape
:obj:`(batch_size, sequence_length, hidden_size)`).
Contains pre-computed buckets and hidden-states that can be used (see
``past_buckets_states`` input) to speed up sequential decoding.
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1579,6 +2027,9 @@ class ReformerModel(ReformerPreTrainedModel):
len(input_shape) == 2
), "`input_ids` have to be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape)
if past_buckets_states is not None:
assert not self.training, "`past_buckets_states` can only be used for inference, not for training."
# prepare head mask
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True)
@@ -1587,8 +2038,12 @@ class ReformerModel(ReformerPreTrainedModel):
# if needs padding
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
min_chunk_length = _get_min_chunk_len(self.config)
must_pad_to_match_chunk_length = (
input_shape[-1] % least_common_mult_chunk_length != 0
and input_shape[-1] > min_chunk_length
and past_buckets_states is None
)
if must_pad_to_match_chunk_length:
@@ -1613,13 +2068,27 @@ class ReformerModel(ReformerPreTrainedModel):
device=device,
)
# start index for position encoding depends on incremental decoding
if past_buckets_states is not None:
start_idx_pos_encodings = past_buckets_states[0][1].shape[1]
else:
start_idx_pos_encodings = 0
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
start_idx_pos_encodings=start_idx_pos_encodings,
)
encoder_outputs = self.encoder(
hidden_states=embedding_output,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
orig_sequence_length=orig_sequence_length,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
@@ -1629,12 +2098,18 @@ class ReformerModel(ReformerPreTrainedModel):
if must_pad_to_match_chunk_length:
sequence_output = sequence_output[:, :orig_sequence_length]
past_buckets_states = encoder_outputs.past_buckets_states if use_cache else None
hidden_states = encoder_outputs.all_hidden_states if output_hidden_states else None
attentions = encoder_outputs.all_attentions if output_attentions else None
if return_tuple:
return tuple(v for v in [sequence_output, past_buckets_states, hidden_states, attentions] if v is not None)
return ReformerModelOutput(
last_hidden_state=sequence_output,
past_buckets_states=past_buckets_states,
hidden_states=hidden_states,
attentions=attentions,
)
def _pad_to_mult_of_chunk_length( def _pad_to_mult_of_chunk_length(
self, self,
...@@ -1659,13 +2134,9 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1659,13 +2134,9 @@ class ReformerModel(ReformerPreTrainedModel):
# Extend `attention_mask` # Extend `attention_mask`
if attention_mask is not None: if attention_mask is not None:
attention_mask = torch.cat( pad_attention_mask = torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype)
[
attention_mask, attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1)
torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,),
],
dim=-1,
)
else: else:
attention_mask = torch.cat( attention_mask = torch.cat(
[ [
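The refactored mask extension pads with zeros, i.e. masked positions. A self-contained sketch with made-up sizes:

```python
# Sketch of the attention-mask extension above: zero (masked) positions are
# appended so the mask matches the padded input length. Sizes are made up.
import torch

attention_mask = torch.ones(2, 70)               # batch of 2, 70 real tokens
padding_length = 128 - attention_mask.shape[-1]  # pad up to a chunk multiple of 128
pad_attention_mask = torch.zeros(
    attention_mask.shape[0], padding_length, dtype=attention_mask.dtype
)
attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1)
assert attention_mask.shape == (2, 128)
```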
...@@ -1698,7 +2169,14 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1698,7 +2169,14 @@ class ReformerModel(ReformerPreTrainedModel):
class ReformerModelWithLMHead(ReformerPreTrainedModel): class ReformerModelWithLMHead(ReformerPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
assert config.is_decoder, "If you want to use `ReformerLMHeadModel` make sure that `is_decoder=True`." assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
assert (
"local" not in self.config.attn_layers or config.local_num_chunks_after == 0
), f"If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not {config.local_num_chunks_after}."
assert (
"lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0
), f"If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not {config.lsh_num_chunks_after}."
self.reformer = ReformerModel(config) self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config) self.lm_head = ReformerOnlyLMHead(config)
...@@ -1726,10 +2204,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -1726,10 +2204,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
num_hashes=None, num_hashes=None,
labels=None, past_buckets_states=None,
use_cache=None,
output_hidden_states=None, output_hidden_states=None,
output_attentions=None, output_attentions=None,
return_tuple=None, return_tuple=None,
labels=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -1747,6 +2227,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -1747,6 +2227,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
num_hashes=num_hashes, num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_attentions=output_attentions, output_attentions=output_attentions,
return_tuple=return_tuple, return_tuple=return_tuple,
...@@ -1768,22 +2250,44 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -1768,22 +2250,44 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
output = (logits,) + reformer_outputs[1:] output = (logits,) + reformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return CausalLMOutput( return ReformerModelWithLMHeadOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
past_buckets_states=reformer_outputs.past_buckets_states,
hidden_states=reformer_outputs.hidden_states, hidden_states=reformer_outputs.hidden_states,
attentions=reformer_outputs.attentions, attentions=reformer_outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past, **kwargs): def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# TODO(PVP): Add smart caching # only use the last token of input_ids if past is defined in kwargs
inputs_dict = {"input_ids": input_ids} if past is not None:
input_ids = input_ids[:, -1:]
inputs_dict = {
"input_ids": input_ids,
"past_buckets_states": past,
"use_cache": kwargs["use_cache"],
}
if "num_hashes" in kwargs: if "num_hashes" in kwargs:
inputs_dict["num_hashes"] = kwargs["num_hashes"] inputs_dict["num_hashes"] = kwargs["num_hashes"]
return inputs_dict return inputs_dict
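What the rewritten `prepare_inputs_for_generation` does during decoding, shown as a tiny standalone sketch (the cache object here is only a placeholder):

```python
# Standalone sketch: once a cache exists, only the newest token is fed back in.
import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
past = [("buckets_placeholder", "hidden_states_placeholder")]  # placeholder cache

if past is not None:
    input_ids = input_ids[:, -1:]

inputs_dict = {"input_ids": input_ids, "past_buckets_states": past, "use_cache": True}
assert inputs_dict["input_ids"].tolist() == [[8]]
```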
def _reorder_cache(self, past, beam_idx):
reord_past_buckets_states = []
for layer_past in past:
# buckets
if layer_past[0] is not None:
reord_buckets = layer_past[0].index_select(0, beam_idx)
else:
reord_buckets = None
# hidden states
reord_hidden_states = layer_past[1].index_select(0, beam_idx)
reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
return reord_past_buckets_states
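During beam search both parts of each layer's cache are gathered along the batch dimension, exactly as `_reorder_cache` does above. A minimal sketch with made-up shapes:

```python
# Sketch of the cache reordering: cached buckets (if present) and cached hidden
# states are both gathered with `index_select` along the beam/batch axis.
import torch

num_beams, seq_len, hidden = 3, 5, 4
buckets = torch.randint(0, 8, (num_beams, 2, 1, seq_len))  # made-up bucket layout
hidden_states = torch.randn(num_beams, seq_len, hidden)
past = [(buckets, hidden_states)]

beam_idx = torch.tensor([2, 0, 0])  # e.g. beam 2 and two copies of beam 0 survive
reordered = [
    (b.index_select(0, beam_idx) if b is not None else None, h.index_select(0, beam_idx))
    for b, h in past
]
assert torch.equal(reordered[0][1][0], hidden_states[2])
```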
@add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING) @add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING)
class ReformerForMaskedLM(ReformerPreTrainedModel): class ReformerForMaskedLM(ReformerPreTrainedModel):
...@@ -1839,6 +2343,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel): ...@@ -1839,6 +2343,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
num_hashes=num_hashes, num_hashes=num_hashes,
use_cache=False, # no causal mask
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_attentions=output_attentions, output_attentions=output_attentions,
return_tuple=return_tuple, return_tuple=return_tuple,
...@@ -2027,6 +2532,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel): ...@@ -2027,6 +2532,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
num_hashes=num_hashes, num_hashes=num_hashes,
use_cache=False, # no causal mask
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_attentions=output_attentions, output_attentions=output_attentions,
return_tuple=return_tuple, return_tuple=return_tuple,
......
...@@ -600,7 +600,7 @@ class XLNetModelOutput(ModelOutput): ...@@ -600,7 +600,7 @@ class XLNetModelOutput(ModelOutput):
@dataclass @dataclass
class XLNetLMHeadModelOutput(ModelOutput): class XLNetLMHeadModelOutput(ModelOutput):
""" """
Output type of :class:`~transformers.XLNetModel`. Output type of :class:`~transformers.XLNetLMHeadModel`.
Args: Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided) loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
...@@ -637,7 +637,7 @@ class XLNetLMHeadModelOutput(ModelOutput): ...@@ -637,7 +637,7 @@ class XLNetLMHeadModelOutput(ModelOutput):
@dataclass @dataclass
class XLNetForSequenceClassificationOutput(ModelOutput): class XLNetForSequenceClassificationOutput(ModelOutput):
""" """
Base class for outputs of sentence classification models. Output type of :class:`~transformers.XLNetForSequenceClassification`.
Args: Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
...@@ -671,7 +671,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput): ...@@ -671,7 +671,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput):
@dataclass @dataclass
class XLNetForTokenClassificationOutput(ModelOutput): class XLNetForTokenClassificationOutput(ModelOutput):
""" """
Base class for outputs of token classification models. Output type of :class:`~transformers.XLNetForTokenClassification`.
Args: Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
......
...@@ -181,8 +181,8 @@ class ReformerModelTester: ...@@ -181,8 +181,8 @@ class ReformerModelTester:
model = ReformerModel(config=config) model = ReformerModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
(sequence_output,) = model(input_ids, attention_mask=input_mask) sequence_output, _ = model(input_ids, attention_mask=input_mask)
(sequence_output,) = model(input_ids) sequence_output, _ = model(input_ids)
result = { result = {
"sequence_output": sequence_output, "sequence_output": sequence_output,
...@@ -193,17 +193,21 @@ class ReformerModelTester: ...@@ -193,17 +193,21 @@ class ReformerModelTester:
) )
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
model = ReformerModelWithLMHead(config=config) config.is_decoder = False
config.lsh_num_chunks_after = 1
model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss.backward() loss.backward()
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
config.lsh_num_chunks_after = 0
config.is_decoder = True
model = ReformerModelWithLMHead(config=config) model = ReformerModelWithLMHead(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = { result = {
"loss": loss, "loss": loss,
"prediction_scores": prediction_scores, "prediction_scores": prediction_scores,
...@@ -332,9 +336,11 @@ class ReformerModelTester: ...@@ -332,9 +336,11 @@ class ReformerModelTester:
config.hidden_dropout_prob = 0 config.hidden_dropout_prob = 0
config.local_attention_probs_dropout_prob = 0 config.local_attention_probs_dropout_prob = 0
config.lsh_attention_probs_dropout_prob = 0 config.lsh_attention_probs_dropout_prob = 0
config.lsh_num_chunks_after = 1
config.is_decoder = False
torch.manual_seed(0) torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
model.zero_grad() model.zero_grad()
...@@ -348,7 +354,7 @@ class ReformerModelTester: ...@@ -348,7 +354,7 @@ class ReformerModelTester:
config.chunk_size_feed_forward = 1 config.chunk_size_feed_forward = 1
torch.manual_seed(0) torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
model.zero_grad() model.zero_grad()
...@@ -405,7 +411,22 @@ class ReformerModelTester: ...@@ -405,7 +411,22 @@ class ReformerModelTester:
output = model(input_ids, attention_mask=input_mask)[0] output = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = True
config.lsh_num_chunks_after = 0
config.bos_token_id = 0
config.eos_token_id = None
config.max_length = 20
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
output = model.generate()
self.parent.assertIsNotNone(output)
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = True
config.lsh_num_chunks_after = 0
model = ReformerModelWithLMHead(config=config) model = ReformerModelWithLMHead(config=config)
model.to(torch_device) model.to(torch_device)
model.half() model.half()
...@@ -418,13 +439,15 @@ class ReformerModelTester: ...@@ -418,13 +439,15 @@ class ReformerModelTester:
# force chunk length to be bigger than input_ids # force chunk length to be bigger than input_ids
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1] config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1] config.local_attn_chunk_length = 2 * input_ids.shape[-1]
model = ReformerModelWithLMHead(config=config) config.lsh_num_chunks_after = 1
config.is_decoder = False
model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0] output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1]) self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
model = ReformerForQuestionAnswering(config=config) model = ReformerForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -440,6 +463,33 @@ class ReformerModelTester: ...@@ -440,6 +463,33 @@ class ReformerModelTester:
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = True
config.lsh_num_chunks_before = 1
config.lsh_num_chunks_after = 0
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
input_ids_first = input_ids[:, :-1]
input_ids_second = input_ids[:, -1:]
# return saved cache
_, past_buckets_states = model(input_ids_first, use_cache=True)
# calculate last output with and without cache
outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
outputs_without_cache = model(input_ids)[0][:, -1]
# select random slice idx
random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
# outputs should be similar within range
self.parent.assertTrue(
torch.allclose(
outputs_with_cache[:, 0, random_slice_idx], outputs_without_cache[:, random_slice_idx], atol=1e-2
)
)
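The same equivalence can be checked outside the tester with a pretrained checkpoint. A hedged usage sketch follows; it reuses the tuple-style unpacking and the checkpoint name from the tests in this diff, and the tolerance is only an assumption.

```python
# Hedged usage sketch of cached inference with a pretrained checkpoint.
import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").eval()

input_ids = tokenizer.encode("A few months later", return_tensors="pt")
with torch.no_grad():
    # run the prefix once and keep the cache
    _, past = model(input_ids[:, :-1], use_cache=True)
    # feed only the last token together with the cache
    logits_cached, _ = model(input_ids[:, -1:], past_buckets_states=past, use_cache=True)
    # reference: full forward pass without caching
    logits_full = model(input_ids)[0][:, -1]
assert torch.allclose(logits_cached[:, 0], logits_full, atol=1e-2)
```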
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, choice_labels) = config_and_inputs (config, input_ids, input_mask, choice_labels) = config_and_inputs
...@@ -509,6 +559,18 @@ class ReformerTesterMixin: ...@@ -509,6 +559,18 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs) self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
def test_reformer_qa_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_for_question_answering(*config_and_inputs)
def test_reformer_cached_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_past_buckets_states(*config_and_inputs)
def test_reformer_cached_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_generate(*config_and_inputs)
@slow @slow
def test_dropout_random_seed_is_changing(self): def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
...@@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T ...@@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"num_buckets": 2, "num_buckets": 2,
"num_hashes": 4, "num_hashes": 4,
"lsh_attn_chunk_length": 4, "lsh_attn_chunk_length": 4,
"lsh_num_chunks_before": 2, "lsh_num_chunks_before": 1,
"lsh_num_chunks_after": 3, "lsh_num_chunks_after": 0,
"chunk_size_lm_head": 5, "chunk_size_lm_head": 5,
"chunk_size_feed_forward": 6, "chunk_size_feed_forward": 6,
"feed_forward_size": 32, "feed_forward_size": 32,
...@@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T ...@@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"axial_pos_embds": True, "axial_pos_embds": True,
"axial_pos_shape": [4, 8], "axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 48], "axial_pos_embds_dim": [16, 48],
"attn_layers": ["lsh", "lsh", "lsh", "lsh"], # sanotheu
# "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
"attn_layers": ["lsh"],
"pad_token_id": 0, "pad_token_id": 0,
"eos_token_id": 2, "eos_token_id": 2,
"scope": None, "scope": None,
...@@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase): ...@@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase):
output_ids = model.generate( output_ids = model.generate(
input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8 input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
) )
output_text = tokenizer.decode(output_ids[0]) output = tokenizer.decode(output_ids[0])
self.assertEqual( self.assertEqual(
output_text, output,
"A few months later state expression in his ideas, at the first entrance. He was positively for an inst", "A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
) )
@slow
def test_pretrained_generate_use_cache_equality(self):
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()
input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
output_ids_with_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)
output_ids_without_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=False)
output_with_cache = tokenizer.decode(output_ids_with_cache[0])
output_without_cache = tokenizer.decode(output_ids_without_cache[0])
self.assertEqual(output_with_cache, output_without_cache)