Unverified Commit d697b6ca authored by Patrick von Platen, committed by GitHub

[Longformer] Major Refactor (#5219)

* refactor naming

* add small slow test

* refactor

* refactor naming

* rename selected to extra

* big global attention refactor

* make style

* refactor naming

* save intermed

* refactor functions

* finish function refactor

* fix tests

* fix longformer

* fix longformer

* fix longformer

* fix all tests but one

* finish longformer

* address sams and izs comments

* fix transpose
parent e0d58ddb
@@ -25,8 +25,9 @@ from torch.nn import functional as F
from .configuration_longformer import LongformerConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertPreTrainedModel
from .modeling_roberta import RobertaLMHead, RobertaModel
from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -113,137 +114,171 @@ class LongformerSelfAttention(nn.Module):
attention_window > 0
), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
self.one_sided_attention_window_size = attention_window // 2
self.one_sided_attn_window_size = attention_window // 2
@staticmethod
def _skew(x, direction):
"""Convert diagonals into columns (or columns into diagonals depending on `direction`"""
x_padded = F.pad(x, direction) # padding value is not important because it will be overwritten
x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
return x_padded
def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
"""pads rows and then flips rows and columns"""
hidden_states_padded = F.pad(
hidden_states_padded, padding
) # padding value is not important because it will be overwritten
hidden_states_padded = hidden_states_padded.view(
*hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
)
return hidden_states_padded
@staticmethod
def _skew2(x):
"""shift every row 1 step to right converting columns into diagonals"""
# X = B x C x M x L
B, C, M, L = x.size()
x = F.pad(x, (0, M + 1)) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten
x = x.view(B, C, -1) # B x C x ML+MM+M
x = x[:, :, :-M] # B x C x ML+MM
x = x.view(B, C, M, M + L) # B x C, M x L+M
x = x[:, :, :, :-1]
return x
def _pad_by_window_overlap_except_last_row(chunked_hidden_states):
"""shift every row 1 step right, converting columns into diagonals"""
total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
chunked_hidden_states = F.pad(
chunked_hidden_states, (0, window_overlap + 1)
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, -1
) # total_num_heads x num_chunks x window_overlap*hidden_dim+window_overlap*window_overlap+window_overlap
chunked_hidden_states = chunked_hidden_states[
:, :, :-window_overlap
] # total_num_heads x num_chunks x window_overlap*hidden_dim+window_overlap*window_overlap
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
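To make the row-shifting trick above easier to follow, here is a minimal standalone sketch (toy shapes chosen only for illustration, not part of the module) that reproduces the pad/view/slice sequence on a tiny tensor:

import torch
from torch.nn import functional as F

# one head, one chunk, window_overlap = 2, hidden_dim = 3
window_overlap, hidden_dim = 2, 3
x = torch.arange(1.0, 7.0).view(1, 1, window_overlap, hidden_dim)

padded = F.pad(x, (0, window_overlap + 1))                        # pad last dim with zeros
flat = padded.view(1, 1, -1)[:, :, :-window_overlap]              # drop the trailing pad values
shifted = flat.view(1, 1, window_overlap, window_overlap + hidden_dim)[:, :, :, :-1]

print(shifted[0, 0])
# tensor([[1., 2., 3., 0.],
#         [0., 4., 5., 6.]])  -> row i ends up shifted i steps to the right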
@staticmethod
def _chunk(x, w):
def _chunk(hidden_states, window_overlap):
"""convert into overlapping chunkings. Chunk size = 2w, overlap size = w"""
# non-overlapping chunks of size = 2w
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
hidden_states = hidden_states.view(
hidden_states.size(0),
hidden_states.size(1) // (window_overlap * 2),
window_overlap * 2,
hidden_states.size(2),
)
# use `as_strided` to make the chunks overlap with an overlap size = w
chunk_size = list(x.size())
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
chunk_size = list(hidden_states.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride = list(hidden_states.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided(size=chunk_size, stride=chunk_stride)
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
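Since the `as_strided` call is the subtle part here, a small self-contained sketch (toy sizes, assuming window_overlap = 2) shows how non-overlapping chunks of size 2 * window_overlap become chunks that overlap by window_overlap:

import torch

window_overlap = 2
hidden_states = torch.arange(8.0).view(1, 8, 1)  # (batch*heads=1, seq_len=8, head_dim=1)

chunks = hidden_states.view(1, 8 // (window_overlap * 2), window_overlap * 2, 1)
size = list(chunks.size())
size[1] = size[1] * 2 - 1             # 2 non-overlapping chunks -> 3 overlapping chunks
stride = list(chunks.stride())
stride[1] = stride[1] // 2            # advance by half a chunk between consecutive chunks
overlapping = chunks.as_strided(size=size, stride=stride)

print(overlapping.squeeze(0).squeeze(-1))
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 6., 7.]])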
def _mask_invalid_locations(self, input_tensor, w) -> torch.Tensor:
affected_seqlen = w
beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor:
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
beginning_mask = beginning_mask.expand(beginning_input.size())
beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
ending_mask = ending_mask.expand(ending_input.size())
ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
"""Matrix multiplicatio of query x key tensors using with a sliding window attention pattern.
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
"""Matrix multiplication of query and key tensors using with a sliding window attention pattern.
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
with an overlap of size w"""
batch_size, seqlen, num_heads, head_dim = q.size()
assert seqlen % (w * 2) == 0, f"Sequence length should be multiple of {w * 2}. Given {seqlen}"
assert q.size() == k.size()
with an overlap of size window_overlap"""
batch_size, seq_len, num_heads, head_dim = query.size()
assert (
seq_len % (window_overlap * 2) == 0
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
assert query.size() == key.size()
chunks_count = seqlen // w - 1
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
q = q.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim)
k = k.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim)
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
chunk_q = self._chunk(q, w)
chunk_k = self._chunk(k, w)
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2w x head_dim
# bcyd: batch_size * num_heads x chunks x 2w x head_dim
# bcxy: batch_size * num_heads x chunks x 2w x 2w
chunk_attn = torch.einsum("bcxd,bcyd->bcxy", (chunk_q, chunk_k)) # multiply
# bcxd: batch_size * num_heads x chunks x 2 * window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2 * window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2 * window_overlap x 2 * window_overlap
chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply
# convert diagonals into columns
diagonal_chunk_attn = self._skew(chunk_attn, direction=(0, 0, 0, 1))
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, padding=(0, 0, 0, 1)
)
# allocate space for the overall attention matrix where the chunks are compined. The last dimension
# has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
# w previous words). The following column is attention score from each word to itself, then
# followed by w columns for the upper triangle.
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
# window_overlap previous words). The following column is attention score from each word to itself, then
# followed by window_overlap columns for the upper triangle.
diagonal_attn = diagonal_chunk_attn.new_empty((batch_size * num_heads, chunks_count + 1, w, w * 2 + 1))
diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
)
# copy parts from diagonal_chunk_attn into the compined matrix of attentions
# copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
# - copying the main diagonal and the upper triangle
diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, : w + 1]
diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, : w + 1]
diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
:, :, :window_overlap, : window_overlap + 1
]
diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
:, -1, window_overlap:, : window_overlap + 1
]
# - copying the lower triangle
diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, -(w + 1) : -1, w + 1 :]
diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, : w - 1, 1 - w :]
diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
:, :, -(window_overlap + 1) : -1, window_overlap + 1 :
]
diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
:, 0, : window_overlap - 1, 1 - window_overlap :
]
# separate batch_size and num_heads dimensions again
diagonal_attn = diagonal_attn.view(batch_size, num_heads, seqlen, 2 * w + 1).transpose(2, 1)
self._mask_invalid_locations(diagonal_attn, w)
return diagonal_attn
def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int):
"""Same as _sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output
format from _sliding_chunks_matmul_qk"""
batch_size, seqlen, num_heads, head_dim = v.size()
assert seqlen % (w * 2) == 0
assert prob.size()[:3] == v.size()[:3]
assert prob.size(3) == 2 * w + 1
chunks_count = seqlen // w - 1
# group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
chunk_prob = prob.transpose(1, 2).reshape(batch_size * num_heads, seqlen // w, w, 2 * w + 1)
diagonal_attention_scores = diagonal_attention_scores.view(
batch_size, num_heads, seq_len, 2 * window_overlap + 1
).transpose(2, 1)
# group batch_size and num_heads dimensions into one
v = v.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim)
self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
return diagonal_attention_scores
# pad seqlen with w at the beginning of the sequence and another w at the end
padded_v = F.pad(v, (0, 0, w, w), value=-1)
def _sliding_chunks_matmul_attn_probs_value(
self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
Returned tensor will be of the same shape as `attn_probs`"""
batch_size, seq_len, num_heads, head_dim = value.size()
assert seq_len % (window_overlap * 2) == 0
assert attn_probs.size()[:3] == value.size()[:3]
assert attn_probs.size(3) == 2 * window_overlap + 1
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
)
# chunk padded_v into chunks of size 3w and an overlap of size w
chunk_v_size = (batch_size * num_heads, chunks_count + 1, 3 * w, head_dim)
chunk_v_stride = padded_v.stride()
chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)
# group batch_size and num_heads dimensions into one
value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
# pad seq_len with window_overlap at the beginning of the sequence and another window_overlap at the end
padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
chunked_value_stride = padded_value.stride()
chunked_value_stride = (
chunked_value_stride[0],
window_overlap * chunked_value_stride[1],
chunked_value_stride[1],
chunked_value_stride[2],
)
chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
skewed_prob = self._skew2(chunk_prob)
chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs)
context = torch.einsum("bcwd,bcdh->bcwh", (skewed_prob, chunk_v))
return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2)
context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
self, hidden_states, attention_mask=None, output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be a multiple of `attention_window`.
@@ -254,186 +289,448 @@ class LongformerSelfAttention(nn.Module):
0: local attention
+ve: global attention
`encoder_hidden_states` and `encoder_attention_mask` are not supported and should be None
"""
# TODO: add support for `encoder_hidden_states` and `encoder_attention_mask`
assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"
if attention_mask is not None:
attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
key_padding_mask = attention_mask < 0
extra_attention_mask = attention_mask > 0
remove_from_windowed_attention_mask = attention_mask != 0
num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
if max_num_extra_indices_per_batch <= 0:
extra_attention_mask = None
else:
# To support the case of variable number of global attention in the rows of a batch,
# we use the following three selection masks to select global attention embeddings
# in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
# 1) selecting embeddings that correspond to global attention
extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
zero_to_max_range = torch.arange(
0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device
)
# mask indicating which values are actually going to be padding
selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
# 2) location of the non-padding values in the selected global attention
selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
# 3) location of the padding values in the selected global attention
selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
else:
remove_from_windowed_attention_mask = None
extra_attention_mask = None
key_padding_mask = None
# is index masked or global attention
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = any(is_index_global_attn.flatten())
hidden_states = hidden_states.transpose(0, 1)
seqlen, batch_size, embed_dim = hidden_states.size()
assert embed_dim == self.embed_dim
q = self.query(hidden_states)
k = self.key(hidden_states)
v = self.value(hidden_states)
q /= math.sqrt(self.head_dim)
q = q.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_weights = (batch_size, seqlen, num_heads, window*2+1)
attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
if remove_from_windowed_attention_mask is not None:
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
# from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(
dim=-1
# project hidden states
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
seq_len, batch_size, embed_dim = hidden_states.size()
assert (
embed_dim == self.embed_dim
), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
# normalize query
query_vectors /= math.sqrt(self.head_dim)
query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_scores = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query_vectors, key_vectors, self.one_sided_attn_window_size
)
# values to pad for attention probs
remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1)
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
remove_from_windowed_attention_mask, -10000.0
)
ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
# diagonal mask with zeros everywhere and -inf in place of padding
d_mask = self._sliding_chunks_matmul_qk(ones, float_mask, self.one_sided_attention_window_size)
attn_weights += d_mask
assert list(attn_weights.size()) == [
diagonal_mask = self._sliding_chunks_query_key_matmul(
float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
)
# pad local attention probs
attn_scores += diagonal_mask
assert list(attn_scores.size()) == [
batch_size,
seqlen,
seq_len,
self.num_heads,
self.one_sided_attention_window_size * 2 + 1,
]
self.one_sided_attn_window_size * 2 + 1,
], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
# compute local attention probs from global attention keys and concat over window dim
if is_global_attn:
# compute global attn indices required throughout forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn)
# calculate global attn probs from global key
global_key_attn_scores = self._concat_with_global_key_attn_probs(
query_vectors=query_vectors,
key_vectors=key_vectors,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
# the extra attention
if extra_attention_mask is not None:
selected_k = k.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
# (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0
# concat to attn_weights
# (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
attn_weights_fp32 = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
attn_weights = attn_weights_fp32.type_as(attn_weights)
if key_padding_mask is not None:
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_weights = torch.masked_fill(attn_weights, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
v = v.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
attn = None
if extra_attention_mask is not None:
selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
selected_v = v.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
# use `matmul` because `einsum` crashes sometimes with fp16
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2)
attn_probs = attn_probs.narrow(
-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
).contiguous()
if attn is None:
attn = self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size)
else:
attn += self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size)
# free memory
del global_key_attn_scores
assert attn.size() == (batch_size, seqlen, self.num_heads, self.head_dim), "Unexpected size"
attn = attn.transpose(0, 1).reshape(seqlen, batch_size, embed_dim).contiguous()
attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
attn_probs = attn_probs_fp32.type_as(attn_scores)
# For this case, we'll just recompute the attention for these indices
# and overwrite the attn tensor.
# TODO: remove the redundant computation
if extra_attention_mask is not None:
selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, batch_size, embed_dim)
selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[
extra_attention_mask_nonzeros[::-1]
]
# free memory
del attn_probs_fp32
q = self.query_global(selected_hidden_states)
k = self.key_global(hidden_states)
v = self.value_global(hidden_states)
q /= math.sqrt(self.head_dim)
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)
# apply dropout
attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# compute local attention output with global attention value and add
if is_global_attn:
# compute sum of global and local attn
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
# compute local attn only
attn_output = self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self.one_sided_attn_window_size
)
q = (
q.contiguous()
.view(max_num_extra_indices_per_batch, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
) # (batch_size * self.num_heads, max_num_extra_indices_per_batch, head_dim)
k = (
k.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seqlen, head_dim)
v = (
v.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seqlen, head_dim)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen]
attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen)
attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
if key_padding_mask is not None:
attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,)
attn_weights = attn_weights.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen)
attn_weights_float = F.softmax(
attn_weights, dim=-1, dtype=torch.float32
) # use fp32 for numerical stability
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
selected_attn = torch.bmm(attn_probs, v)
assert list(selected_attn.size()) == [
batch_size * self.num_heads,
max_num_extra_indices_per_batch,
self.head_dim,
]
assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
selected_attn_4d = selected_attn.view(
batch_size, self.num_heads, max_num_extra_indices_per_batch, self.head_dim
# compute value for global attention and overwrite the attention output
# TODO: remove the redundant computation
if is_global_attn:
global_attn_output = self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
)
nonzero_selected_attn = selected_attn_4d[
selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]
# get only non zero global attn output
nonzero_global_attn_output = global_attn_output[
is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
]
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
len(selection_padding_mask_nonzeros[0]), -1
# overwrite values with global attention
attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
len(is_local_index_global_attn_nonzero[0]), -1
)
context_layer = attn.transpose(0, 1)
attn_output = attn_output.transpose(0, 1)
if output_attentions:
if extra_attention_mask is not None:
if is_global_attn:
# With global attention, return global attention probabilities only
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
# which is the attention weights from tokens with global attention to all tokens
# It does not return local attention
# In case of a variable number of global attention tokens in the rows of a batch,
# attn_weights are padded with -10000.0 attention scores
attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen)
# attn_probs are padded with -10000.0 attention scores
attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
else:
# without global attention, return local attention probabilities
# batch_size x num_heads x sequence_length x window_size
# which is the attention weights of every token attending to its neighbours
attn_weights = attn_weights.permute(0, 2, 1, 3)
outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
attn_probs = attn_probs.permute(0, 2, 1, 3)
outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
return outputs
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
# helper variable
num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
# max number of global attn indices in batch
max_num_global_attn_indices = num_global_attn_indices.max()
# indices of global attn
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
# helper variable
is_local_index_global_attn = torch.arange(
max_num_global_attn_indices, device=is_index_global_attn.device
) < num_global_attn_indices.unsqueeze(dim=-1)
# location of the non-padding values within global attention indices
is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)
# location of the padding values within global attention indices
is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)
return (
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
)
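A quick, hypothetical walk-through of the index bookkeeping above (the mask values are chosen only for illustration): with two rows that contain 2 and 1 global tokens respectively, the helper tensors look like this:

import torch

is_index_global_attn = torch.tensor(
    [[True, False, False, True, False],
     [True, False, False, False, False]]
)
num_global_attn_indices = is_index_global_attn.long().sum(dim=1)        # tensor([2, 1])
max_num_global_attn_indices = num_global_attn_indices.max()             # tensor(2)
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
# (tensor([0, 0, 1]), tensor([0, 3, 0])) -> positions of the global tokens per row
is_local_index_global_attn = torch.arange(
    max_num_global_attn_indices
) < num_global_attn_indices.unsqueeze(dim=-1)
print(is_local_index_global_attn)
# tensor([[ True,  True],
#         [ True, False]])  -> the second "slot" of row 1 is padding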
def _concat_with_global_key_attn_probs(
self,
key_vectors,
query_vectors,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
):
batch_size = key_vectors.shape[0]
# create only global key vectors
key_vectors_only_global = key_vectors.new_zeros(
batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
)
key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
] = -10000.0
return attn_probs_from_global_key
def _compute_attn_output_with_global_indices(
self,
value_vectors,
attn_probs,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
):
batch_size = attn_probs.shape[0]
# cut local attn probs to global only
attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
# get value vectors for global only
value_vectors_only_global = value_vectors.new_zeros(
batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
)
value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]
# use `matmul` because `einsum` crashes sometimes with fp16
# attn_output_only_global = torch.einsum('blhs,bshd->blhd', (attn_probs_only_global, value_vectors_only_global))
# compute attn output only global
attn_output_only_global = torch.matmul(
attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2)
).transpose(1, 2)
# reshape attn probs
attn_probs_without_global = attn_probs.narrow(
-1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
).contiguous()
# compute attn output without global
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
)
return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
self,
hidden_states,
max_num_global_attn_indices,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
is_index_masked,
):
seq_len, batch_size = hidden_states.shape[:2]
# prepare global hidden states
global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
is_index_global_attn_nonzero[::-1]
]
# global key, query, value
global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
global_key_vectors = self.key_global(hidden_states)
global_value_vectors = self.value_global(hidden_states)
# normalize
global_query_vectors_only_global /= math.sqrt(self.head_dim)
# reshape
global_query_vectors_only_global = (
global_query_vectors_only_global.contiguous()
.view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
global_key_vectors = (
global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seq_len, head_dim)
global_value_vectors = (
global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seq_len, head_dim)
# compute attn scores
global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))
assert list(global_attn_scores.size()) == [
batch_size * self.num_heads,
max_num_global_attn_indices,
seq_len,
], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_scores[
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
] = -10000.0
global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
# compute global attn probs
global_attn_probs_float = F.softmax(
global_attn_scores, dim=-1, dtype=torch.float32
) # use fp32 for numerical stability
global_attn_probs = F.dropout(
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
)
# global attn output
global_attn_output = torch.bmm(global_attn_probs, global_value_vectors)
assert list(global_attn_output.size()) == [
batch_size * self.num_heads,
max_num_global_attn_indices,
self.head_dim,
], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."
global_attn_output = global_attn_output.view(
batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
)
return global_attn_output
class LongformerAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.self = LongformerSelfAttention(config, layer_id)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
self_outputs = self.self(hidden_states, attention_mask, output_attentions,)
attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class LongformerLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.attention = LongformerAttention(config, layer_id)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,)
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
intermediate_output = self.intermediate(attn_output)
layer_output = self.output(intermediate_output, attn_output)
outputs = (layer_output,) + outputs
return outputs
class LongformerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
def forward(
self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, attention_mask,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class LongformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained
models.
"""
config_class = LongformerConfig
base_model_prefix = "longformer"
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
LONGFORMER_START_DOCSTRING = r"""
@@ -498,7 +795,7 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
"The bare Longformer Model outputting raw hidden-states without any specific head on top.",
LONGFORMER_START_DOCSTRING,
)
class LongformerModel(RobertaModel):
class LongformerModel(LongformerPreTrainedModel):
"""
This class overrides :class:`~transformers.RobertaModel` to provide the ability to process
long sequences following the self-attention approach described in `Longformer: the Long-Document Transformer
@@ -519,6 +816,7 @@ class LongformerModel(RobertaModel):
def __init__(self, config):
super().__init__(config)
self.config = config
if isinstance(config.attention_window, int):
assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
@@ -530,12 +828,26 @@ class LongformerModel(RobertaModel):
f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
)
for i, layer in enumerate(self.encoder.layer):
# replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
layer.attention.self = LongformerSelfAttention(config, layer_id=i)
self.embeddings = RobertaEmbeddings(config)
self.encoder = LongformerEncoder(config)
self.pooler = BertPooler(config)
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def _pad_to_window_size(
self,
input_ids: torch.Tensor,
@@ -543,30 +855,29 @@ class LongformerModel(RobertaModel):
token_type_ids: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
attention_window: int,
pad_token_id: int,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
# padding
attention_window = (
self.config.attention_window
if isinstance(self.config.attention_window, int)
else max(self.config.attention_window)
)
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
batch_size, seqlen = input_shape[:2]
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seqlen % attention_window) % attention_window
padding_len = (attention_window - seq_len % attention_window) % attention_window
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seqlen, seqlen + padding_len, attention_window
seq_len, seq_len + padding_len, attention_window
)
)
if input_ids is not None:
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
if attention_mask is not None:
attention_mask = F.pad(
attention_mask, (0, padding_len), value=False
) # no attention on the padding tokens
if token_type_ids is not None:
token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
@@ -577,8 +888,23 @@ class LongformerModel(RobertaModel):
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
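As a quick sanity check of the padding arithmetic used above (the numbers are illustrative, not taken from the tests):

attention_window, seq_len = 512, 1300
padding_len = (attention_window - seq_len % attention_window) % attention_window
print(padding_len, seq_len + padding_len)  # 236 1536, i.e. the padded length is 3 * 512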
def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention, 2 for global attention
if attention_mask is not None:
attention_mask = attention_mask * (global_attention_mask + 1)
else:
# simply use `global_attention_mask` as `attention_mask`
# if no `attention_mask` is given
attention_mask = global_attention_mask + 1
return attention_mask
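The merge rule is easiest to see on a toy pair of masks; a minimal sketch (values are illustrative) of the resulting 0/1/2 convention:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # last token is padding
global_attention_mask = torch.tensor([[1, 0, 0, 0]])   # global attention on the first token
merged = attention_mask * (global_attention_mask + 1)
print(merged)  # tensor([[2, 1, 1, 0]]) -> 0: no attention, 1: local, 2: global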
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
@@ -634,24 +960,25 @@ class LongformerModel(RobertaModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# padding
attention_window = (
self.config.attention_window
if isinstance(self.config.attention_window, int)
else max(self.config.attention_window)
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
if attention_mask is not None:
attention_mask = attention_mask * (global_attention_mask + 1)
else:
# simply use `global_attention_mask` as `attention_mask`
# if no `attention_mask` is given
attention_mask = global_attention_mask + 1
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
input_ids=input_ids,
@@ -659,23 +986,29 @@ class LongformerModel(RobertaModel):
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
attention_window=attention_window,
pad_token_id=self.config.pad_token_id,
)
# embed
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=None,
inputs_embeds=inputs_embeds,
encoder_hidden_states=None,
encoder_attention_mask=None,
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
# undo padding
if padding_len > 0:
@@ -684,13 +1017,13 @@ class LongformerModel(RobertaModel):
# `pooled_output`: independent of the sequence length
# `hidden_states`: mainly used for debugging and analysis, so keep the padding
# `attentions`: mainly used for debugging and analysis, so keep the padding
output = output[0][:, :-padding_len], *output[1:]
outputs = outputs[0][:, :-padding_len], *outputs[1:]
return output
return outputs
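For reference, a hedged usage sketch of the refactored model: the checkpoint name comes from the slow tests further down, while the input text and the choice of global token are purely illustrative.

import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

input_ids = tokenizer.encode("Hello world!", return_tensors="pt")
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1  # put global attention on the <s> token

# the model returns a tuple; the first two entries are sequence and pooled outputs
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)[:2]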
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(BertPreTrainedModel):
class LongformerForMaskedLM(LongformerPreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
@@ -811,7 +811,7 @@ class ModelTesterMixin:
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**self._prepare_for_class(inputs_dict, model_class))
global_rng = random.Random()
@@ -115,6 +115,18 @@ class LongformerModelTester:
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_attention_mask_determinism(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
model.to(torch_device)
model.eval()
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
def create_and_check_longformer_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -134,6 +146,36 @@ class LongformerModelTester:
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
model.to(torch_device)
model.eval()
global_attention_mask = input_mask.clone()
global_attention_mask[:, input_mask.shape[-1] // 2] = 0
global_attention_mask = global_attention_mask.to(torch_device)
sequence_output, pooled_output = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
)
sequence_output, pooled_output = model(
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -243,7 +285,13 @@ class LongformerModelTester:
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
global_attention_mask = torch.zeros_like(input_ids)
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": input_mask,
"global_attention_mask": global_attention_mask,
}
return config, inputs_dict
def prepare_config_and_inputs_for_question_answering(self):
@@ -277,11 +325,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
(
LongformerModel,
LongformerForMaskedLM,
# TODO: make tests pass for those models
# LongformerForSequenceClassification,
# LongformerForQuestionAnswering,
# LongformerForTokenClassification,
# LongformerForMultipleChoice,
LongformerForSequenceClassification,
LongformerForQuestionAnswering,
LongformerForTokenClassification,
LongformerForMultipleChoice,
)
if is_torch_available()
else ()
@@ -298,6 +345,14 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
def test_longformer_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
@@ -325,15 +380,31 @@ class LongformerModelIntegrationTest(unittest.TestCase):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world!'
input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
expected_output_slice = torch.tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], device=torch_device)
self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4))
self.assertTrue(torch.allclose(output_without_mask[0, 0, -5:], expected_output_slice, atol=1e-4))
@slow
def test_inference_no_head_long(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
input_ids = torch.tensor(
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
) # long input
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
global_attention_mask[:, [1, 4, 21]] = 1 # Set global attention on a few random positions
output = model(input_ids, attention_mask=attention_mask)[0]
output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
expected_output_mean = torch.tensor(0.0243, device=torch_device)
@@ -341,7 +412,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
@slow
def test_inference_masked_lm(self):
def test_inference_masked_lm_long(self):
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
@@ -352,9 +423,9 @@ class LongformerModelIntegrationTest(unittest.TestCase):
loss, prediction_scores = model(input_ids, labels=input_ids)
expected_loss = torch.tensor(0.0620, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
expected_loss = torch.tensor(0.0074, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
input_ids = input_ids.to(torch_device)
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))