Unverified commit d697b6ca authored by Patrick von Platen, committed by GitHub

[Longformer] Major Refactor (#5219)

* refactor naming

* add small slow test

* refactor

* refactor naming

* rename selected to extra

* big global attention refactor

* make style

* refactor naming

* save intermed

* refactor functions

* finish function refactor

* fix tests

* fix longformer

* fix longformer

* fix longformer

* fix all tests but one

* finish longformer

* address sams and izs comments

* fix transpose
parent e0d58ddb
@@ -25,8 +25,9 @@ from torch.nn import functional as F
from .configuration_longformer import LongformerConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -113,137 +114,171 @@ class LongformerSelfAttention(nn.Module):
attention_window > 0
), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
self.one_sided_attn_window_size = attention_window // 2
@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
"""pads rows and then flips rows and columns"""
hidden_states_padded = F.pad(
hidden_states_padded, padding
)  # padding value is not important because it will be overwritten
hidden_states_padded = hidden_states_padded.view(
*hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
)
return hidden_states_padded
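# --- Illustrative sketch (not part of this commit): what the pad-and-transpose trick above
# does to a tiny 3x3 block of chunk scores. Only torch is assumed; the variable names are made up.
import torch
from torch.nn import functional as F

scores = torch.arange(1, 10).view(1, 1, 3, 3).float()   # one head, one chunk, 3x3 scores
padded = F.pad(scores, (0, 0, 0, 1))                    # append one (arbitrary) row of zeros
skewed = padded.view(1, 1, 3, 4)                        # reinterpret memory with swapped trailing sizes
# skewed[0, 0] is
# [[1., 2., 3., 4.],
#  [5., 6., 7., 8.],
#  [9., 0., 0., 0.]]
# the main diagonal (1, 5, 9) of the original block now sits in column 0, i.e. diagonals of the
# chunked attention scores are lined up as columns, which is the layout the callers below rely on.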
@staticmethod
def _pad_by_window_overlap_except_last_row(chunked_hidden_states):
"""shift every row 1 step right, converting columns into diagonals"""
total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
chunked_hidden_states = F.pad(
chunked_hidden_states, (0, window_overlap + 1)
)  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap + 1). Padding value is not important because it will be overwritten
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, -1
)  # total_num_heads x num_chunks x (window_overlap * hidden_dim + window_overlap * window_overlap + window_overlap)
chunked_hidden_states = chunked_hidden_states[
:, :, :-window_overlap
]  # total_num_heads x num_chunks x (window_overlap * hidden_dim + window_overlap * window_overlap)
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
)  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
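# --- Illustrative sketch (not part of this commit): the row shift performed by the helper above
# for window_overlap = 3 and hidden_dim = 4. Only torch is assumed; names are made up.
import torch
from torch.nn import functional as F

chunk = torch.arange(1, 13).view(1, 1, 3, 4).float()    # (total_num_heads, num_chunks, window_overlap, hidden_dim)
padded = F.pad(chunk, (0, 3 + 1))                       # pad last dim by window_overlap + 1
shifted = padded.view(1, 1, -1)[:, :, :-3].view(1, 1, 3, 3 + 4)[:, :, :, :-1]
# shifted[0, 0] is
# [[ 1.,  2.,  3.,  4.,  0.,  0.],
#  [ 0.,  5.,  6.,  7.,  8.,  0.],
#  [ 0.,  0.,  9., 10., 11., 12.]]
# every row moved one step further right than the previous one, so values that used to share a
# column now lie on a diagonal, which aligns the attention probs with the chunked value vectors.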
@staticmethod
def _chunk(hidden_states, window_overlap):
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
# non-overlapping chunks of size = 2w
hidden_states = hidden_states.view(
hidden_states.size(0),
hidden_states.size(1) // (window_overlap * 2),
window_overlap * 2,
hidden_states.size(2),
)
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
chunk_size = list(hidden_states.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(hidden_states.stride())
chunk_stride[1] = chunk_stride[1] // 2
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
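# --- Illustrative sketch (not part of this commit): the overlapping chunks produced by _chunk
# for a toy sequence of length 8 with window_overlap = 2. Only torch is assumed.
import torch

hidden = torch.arange(8).view(1, 8, 1).float()          # (batch * num_heads, seq_len, head_dim)
w = 2
x = hidden.view(1, 8 // (2 * w), 2 * w, 1)              # two non-overlapping chunks of length 2w
size = list(x.size())
size[1] = size[1] * 2 - 1                               # 2 chunks become 3 overlapping chunks
stride = list(x.stride())
stride[1] = stride[1] // 2                              # step forward by w instead of 2w
chunks = x.as_strided(size=size, stride=stride)
# chunks[0, :, :, 0] is
# [[0., 1., 2., 3.],
#  [2., 3., 4., 5.],
#  [4., 5., 6., 7.]]
# consecutive chunks share w positions, so every token can see w neighbours on each side
# without ever materialising the full seq_len x seq_len attention matrix.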
def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor:
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
beginning_mask = beginning_mask.expand(beginning_input.size())
beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
ending_mask = ending_mask.expand(ending_input.size())
ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
"""Matrix multiplication of query and key tensors using a sliding window attention pattern.
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
with an overlap of size window_overlap"""
batch_size, seq_len, num_heads, head_dim = query.size()
assert (
seq_len % (window_overlap * 2) == 0
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
assert query.size() == key.size()
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2 * window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2 * window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2 * window_overlap x 2 * window_overlap
chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key))  # multiply
# convert diagonals into columns
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, padding=(0, 0, 0, 1)
)
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first window_overlap columns are the lower triangles (attention from a word to
# window_overlap previous words). The following column is the attention score from each word to itself, then
# followed by window_overlap columns for the upper triangle.
diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
)
# copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
# - copying the main diagonal and the upper triangle
diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
:, :, :window_overlap, : window_overlap + 1
]
diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
:, -1, window_overlap:, : window_overlap + 1
]
# - copying the lower triangle
diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
:, :, -(window_overlap + 1) : -1, window_overlap + 1 :
]
diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
:, 0, : window_overlap - 1, 1 - window_overlap :
]
# separate batch_size and num_heads dimensions again
diagonal_attention_scores = diagonal_attention_scores.view(
batch_size, num_heads, seq_len, 2 * window_overlap + 1
).transpose(2, 1)
self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
return diagonal_attention_scores
def _sliding_chunks_matmul_attn_probs_value(
self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
Returned tensor will be of the same shape as `attn_probs`"""
batch_size, seq_len, num_heads, head_dim = value.size()
assert seq_len % (window_overlap * 2) == 0
assert attn_probs.size()[:3] == value.size()[:3]
assert attn_probs.size(3) == 2 * window_overlap + 1
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
)
# group batch_size and num_heads dimensions into one
value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
# pad seq_len with window_overlap at the beginning of the sequence and another window_overlap at the end
padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
chunked_value_stride = padded_value.stride()
chunked_value_stride = (
chunked_value_stride[0],
window_overlap * chunked_value_stride[1],
chunked_value_stride[1],
chunked_value_stride[2],
)
chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs)
context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
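# --- Illustrative sketch (not part of this commit): a naive O(seq_len * window) reference for the
# quantity _sliding_chunks_query_key_matmul computes, ignoring the scaling and masking done in
# forward. Only torch is assumed; the function name is made up.
import torch

def naive_sliding_scores(query, key, window_overlap):
    # query, key: (seq_len, head_dim); returns (seq_len, 2 * window_overlap + 1) where column j
    # holds the score of token i against token i - window_overlap + j (out-of-range stays -inf,
    # matching what _mask_invalid_locations does at the sequence boundaries).
    seq_len, _ = query.shape
    scores = torch.full((seq_len, 2 * window_overlap + 1), float("-inf"))
    for i in range(seq_len):
        for j in range(-window_overlap, window_overlap + 1):
            if 0 <= i + j < seq_len:
                scores[i, j + window_overlap] = query[i] @ key[i + j]
    return scores

print(naive_sliding_scores(torch.randn(8, 4), torch.randn(8, 4), 2).shape)  # torch.Size([8, 5])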
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`.
@@ -254,186 +289,448 @@ class LongformerSelfAttention(nn.Module):
0: local attention
+ve: global attention
"""
attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
# is index masked or global attention
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = any(is_index_global_attn.flatten())
hidden_states = hidden_states.transpose(0, 1)
# project hidden states
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
seq_len, batch_size, embed_dim = hidden_states.size()
assert (
embed_dim == self.embed_dim
), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
# normalize query
query_vectors /= math.sqrt(self.head_dim)
query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query_vectors, key_vectors, self.one_sided_attn_window_size
)
# values to pad for attention probs
remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1)
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
remove_from_windowed_attention_mask, -10000.0
)
# diagonal mask with zeros everywhere and -inf in place of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
)
# pad local attention probs
attn_scores += diagonal_mask
assert list(attn_scores.size()) == [
batch_size,
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
# compute local attention probs from global attention keys and concat over window dim
if is_global_attn:
# compute global attn indices required through out forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn)
# calculate global attn probs from global key
global_key_attn_scores = self._concat_with_global_key_attn_probs(
query_vectors=query_vectors,
key_vectors=key_vectors,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
# free memory
del global_key_attn_scores
attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
attn_probs = attn_probs_fp32.type_as(attn_scores)
# free memory
del attn_probs_fp32
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)
# apply dropout
attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# compute local attention output with global attention value and add
if is_global_attn:
# compute sum of global and local attn
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
# compute local attn only
attn_output = self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self.one_sided_attn_window_size
)
assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
if is_global_attn:
global_attn_output = self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
)
# get only non zero global attn output
nonzero_global_attn_output = global_attn_output[
is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
]
# overwrite values with global attention
attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
len(is_local_index_global_attn_nonzero[0]), -1
)
attn_output = attn_output.transpose(0, 1)
if output_attentions:
if is_global_attn:
# With global attention, return global attention probabilities only
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
# which is the attention weights from tokens with global attention to all tokens
# It does not return local attention
# In case of a variable number of global attention tokens in the rows of a batch,
# attn_probs are padded with -10000.0 attention scores
attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
else:
# without global attention, return local attention probabilities
# batch_size x num_heads x sequence_length x window_size
# which is the attention weights of every token attending to its neighbours
attn_probs = attn_probs.permute(0, 2, 1, 3)
outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
return outputs
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
# helper variable
num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
# max number of global attn indices in batch
max_num_global_attn_indices = num_global_attn_indices.max()
# indices of global attn
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
# helper variable
is_local_index_global_attn = torch.arange(
max_num_global_attn_indices, device=is_index_global_attn.device
) < num_global_attn_indices.unsqueeze(dim=-1)
# location of the non-padding values within global attention indices
is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)
# location of the padding values within global attention indices
is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)
return (
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
)
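# --- Illustrative sketch (not part of this commit): what the helper above returns for a toy
# batch of two sequences where row 0 marks one global token and row 1 marks two.
import torch

is_index_global_attn = torch.tensor([[True, False, False, False],
                                     [True, False, True, False]])
num_global_attn_indices = is_index_global_attn.long().sum(dim=1)           # tensor([1, 2])
max_num_global_attn_indices = num_global_attn_indices.max()                # tensor(2)
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
# (tensor([0, 1, 1]), tensor([0, 0, 2])): (batch index, position) of every global token
is_local_index_global_attn = torch.arange(
    max_num_global_attn_indices
) < num_global_attn_indices.unsqueeze(dim=-1)
# tensor([[ True, False],
#         [ True,  True]]): row 0 only fills slot 0, so slot 1 of row 0 is padding and later
# receives the -10000.0 scores mentioned in the attention-probs comments above.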
def _concat_with_global_key_attn_probs(
self,
key_vectors,
query_vectors,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
):
batch_size = key_vectors.shape[0]
# create only global key vectors
key_vectors_only_global = key_vectors.new_zeros(
batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
)
key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
] = -10000.0
return attn_probs_from_global_key
def _compute_attn_output_with_global_indices(
self,
value_vectors,
attn_probs,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
):
batch_size = attn_probs.shape[0]
# cut local attn probs to global only
attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
# get value vectors for global only
value_vectors_only_global = value_vectors.new_zeros(
batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
)
value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]
# use `matmul` because `einsum` crashes sometimes with fp16
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
# compute attn output only global
attn_output_only_global = torch.matmul(
attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2)
).transpose(1, 2)
# reshape attn probs
attn_probs_without_global = attn_probs.narrow(
-1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
).contiguous()
# compute attn output with global
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
)
return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
self,
hidden_states,
max_num_global_attn_indices,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
is_index_masked,
):
seq_len, batch_size = hidden_states.shape[:2]
# prepare global hidden states
global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
is_index_global_attn_nonzero[::-1]
]
# global key, query, value
global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
global_key_vectors = self.key_global(hidden_states)
global_value_vectors = self.value_global(hidden_states)
# normalize
global_query_vectors_only_global /= math.sqrt(self.head_dim)
# reshape
global_query_vectors_only_global = (
global_query_vectors_only_global.contiguous()
.view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
global_key_vectors = (
global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seq_len, head_dim)
global_value_vectors = (
global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seq_len, head_dim)
# compute attn scores
global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))
assert list(global_attn_scores.size()) == [
batch_size * self.num_heads,
max_num_global_attn_indices,
seq_len,
], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_scores[
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
] = -10000.0
global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
# compute global attn probs
global_attn_probs_float = F.softmax(
global_attn_scores, dim=-1, dtype=torch.float32
) # use fp32 for numerical stability
global_attn_probs = F.dropout(
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
)
# global attn output
global_attn_output = torch.bmm(global_attn_probs, global_value_vectors)
assert list(global_attn_output.size()) == [
batch_size * self.num_heads,
max_num_global_attn_indices,
self.head_dim,
], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."
global_attn_output = global_attn_output.view(
batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
)
return global_attn_output
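# --- Illustrative sketch (not part of this commit): the bmm shapes used in the global branch
# above, with batch_size=1, num_heads=2, max_num_global_attn_indices=3, seq_len=8, head_dim=4.
import torch

global_query = torch.randn(1 * 2, 3, 4)   # every global token queries ...
global_key = torch.randn(1 * 2, 8, 4)     # ... every position of the (padded) sequence
global_scores = torch.bmm(global_query, global_key.transpose(1, 2))
print(global_scores.shape)                # torch.Size([2, 3, 8]): (batch*heads, n_global, seq_len)
# this matches the size assert on global_attn_scores above: global tokens attend to the whole
# sequence, so their score matrix is only n_global x seq_len rather than seq_len x seq_len.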
class LongformerAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.self = LongformerSelfAttention(config, layer_id)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
self_outputs = self.self(hidden_states, attention_mask, output_attentions,)
attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class LongformerLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.attention = LongformerAttention(config, layer_id)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,)
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
intermediate_output = self.intermediate(attn_output)
layer_output = self.output(intermediate_output, attn_output)
outputs = (layer_output,) + outputs
return outputs
class LongformerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
def forward(
self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, attention_mask,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class LongformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained
models.
"""
config_class = LongformerConfig
base_model_prefix = "longformer"
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
LONGFORMER_START_DOCSTRING = r""" LONGFORMER_START_DOCSTRING = r"""
...@@ -498,7 +795,7 @@ LONGFORMER_INPUTS_DOCSTRING = r""" ...@@ -498,7 +795,7 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
"The bare Longformer Model outputting raw hidden-states without any specific head on top.", "The bare Longformer Model outputting raw hidden-states without any specific head on top.",
LONGFORMER_START_DOCSTRING, LONGFORMER_START_DOCSTRING,
) )
class LongformerModel(RobertaModel): class LongformerModel(LongformerPreTrainedModel):
""" """
This class overrides :class:`~transformers.RobertaModel` to provide the ability to process This class overrides :class:`~transformers.RobertaModel` to provide the ability to process
long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer
...@@ -519,6 +816,7 @@ class LongformerModel(RobertaModel): ...@@ -519,6 +816,7 @@ class LongformerModel(RobertaModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.config = config
if isinstance(config.attention_window, int): if isinstance(config.attention_window, int):
assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
...@@ -530,12 +828,26 @@ class LongformerModel(RobertaModel): ...@@ -530,12 +828,26 @@ class LongformerModel(RobertaModel):
f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
) )
for i, layer in enumerate(self.encoder.layer): self.embeddings = RobertaEmbeddings(config)
# replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention` self.encoder = LongformerEncoder(config)
layer.attention.self = LongformerSelfAttention(config, layer_id=i) self.pooler = BertPooler(config)
self.init_weights() self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def _pad_to_window_size(
self,
input_ids: torch.Tensor,
@@ -543,30 +855,29 @@ class LongformerModel(RobertaModel):
token_type_ids: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
pad_token_id: int,
):
"""A helper function to pad tokens and mask to work with the implementation of Longformer self-attention."""
# padding
attention_window = (
self.config.attention_window
if isinstance(self.config.attention_window, int)
else max(self.config.attention_window)
)
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
)
)
if input_ids is not None:
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
@@ -577,8 +888,23 @@ class LongformerModel(RobertaModel):
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
attention_mask = F.pad(attention_mask, (0, padding_len), value=False)  # no attention on the padding tokens
token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0)  # pad with token_type_id = 0
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
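# --- Illustrative sketch (not part of this commit): the padding arithmetic above for the
# pretrained checkpoints, which use attention_window = 512.
seq_len = 1300
attention_window = 512
padding_len = (attention_window - seq_len % attention_window) % attention_window
print(padding_len)   # 236, so the padded length is 1536 = 3 * 512
# the trailing "% attention_window" keeps padding_len at 0 when seq_len is already a multiple.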
def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
if attention_mask is not None:
attention_mask = attention_mask * (global_attention_mask + 1)
else:
# simply use `global_attention_mask` as `attention_mask`
# if no `attention_mask` is given
attention_mask = global_attention_mask + 1
return attention_mask
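# --- Illustrative sketch (not part of this commit): how _merge_to_attention_mask combines the
# two user-facing masks into the 0 / 1 / 2 encoding that LongformerSelfAttention expects.
import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])         # last two tokens are padding
global_attention_mask = torch.tensor([[1, 0, 0, 0, 0, 0]])  # global attention on the first token
merged = attention_mask * (global_attention_mask + 1)
print(merged)   # tensor([[2, 1, 1, 1, 0, 0]]): 0 = no attention, 1 = local, 2 = global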
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
@@ -634,24 +960,25 @@ class LongformerModel(RobertaModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
input_ids=input_ids,
@@ -659,23 +986,29 @@ class LongformerModel(RobertaModel):
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
pad_token_id=self.config.pad_token_id,
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
# undo padding
if padding_len > 0:
@@ -684,13 +1017,13 @@ class LongformerModel(RobertaModel):
# `pooled_output`: independent of the sequence length
# `hidden_states`: mainly used for debugging and analysis, so keep the padding
# `attentions`: mainly used for debugging and analysis, so keep the padding
outputs = outputs[0][:, :-padding_len], *outputs[1:]
return outputs
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(LongformerPreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
...
@@ -811,7 +811,7 @@ class ModelTesterMixin:
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
global_rng = random.Random()
...
@@ -115,6 +115,18 @@ class LongformerModelTester:
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_attention_mask_determinism(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
model.to(torch_device)
model.eval()
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
def create_and_check_longformer_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -134,6 +146,36 @@ class LongformerModelTester:
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
model.to(torch_device)
model.eval()
global_attention_mask = input_mask.clone()
global_attention_mask[:, input_mask.shape[-1] // 2] = 0
global_attention_mask = global_attention_mask.to(torch_device)
sequence_output, pooled_output = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
)
sequence_output, pooled_output = model(
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -243,7 +285,13 @@ class LongformerModelTester:
token_labels,
choice_labels,
) = config_and_inputs
global_attention_mask = torch.zeros_like(input_ids)
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": input_mask,
"global_attention_mask": global_attention_mask,
}
return config, inputs_dict
def prepare_config_and_inputs_for_question_answering(self):
@@ -277,11 +325,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
(
LongformerModel,
LongformerForMaskedLM,
LongformerForSequenceClassification,
LongformerForQuestionAnswering,
LongformerForTokenClassification,
LongformerForMultipleChoice,
)
if is_torch_available()
else ()
@@ -298,6 +345,14 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
def test_longformer_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
@@ -325,15 +380,31 @@ class LongformerModelIntegrationTest(unittest.TestCase):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world!'
input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
expected_output_slice = torch.tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], device=torch_device)
self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4))
self.assertTrue(torch.allclose(output_without_mask[0, 0, -5:], expected_output_slice, atol=1e-4))
@slow
def test_inference_no_head_long(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
input_ids = torch.tensor(
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
)  # long input
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
global_attention_mask[:, [1, 4, 21]] = 1  # Set global attention on a few random positions
output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
expected_output_mean = torch.tensor(0.0243, device=torch_device)
@@ -341,7 +412,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
@slow
def test_inference_masked_lm_long(self):
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
@@ -352,9 +423,9 @@ class LongformerModelIntegrationTest(unittest.TestCase):
loss, prediction_scores = model(input_ids, labels=input_ids)
expected_loss = torch.tensor(0.0074, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
input_ids = input_ids.to(torch_device)
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
...