Unverified commit 3223d493, authored by Patrick Deutschmann, committed by GitHub

Add ONNX support for Longformer (#17176)

* Implement ONNX support for Longformer

Fix repo consistency check complaints

Fix value mismatches

Add pooler output for default model

Increase validation atol to accommodate multiple-choice error

Fix copies

Fix chunking for longer sequence lengths

Add future comment

* Fix issue in mask_invalid_locations

* Remove torch imports in configuration_longformer

* Change config access to fix LED

* Push opset version to support tril

* Address review comments (mostly style)

* Add Longformer to ONNX tests
parent c55d6e4e
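For context, this is the export path the PR enables. A minimal sketch, not part of the diff: it assumes the `transformers.onnx.export` entry point and that `LongformerOnnxConfig` is exposed from the longformer module once the registration below is merged.

```python
# Sketch of exporting a Longformer checkpoint to ONNX (assumptions noted above).
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.longformer import LongformerOnnxConfig
from transformers.onnx import export

checkpoint = "allenai/longformer-base-4096"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

onnx_config = LongformerOnnxConfig(model.config, task="default")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("longformer.onnx")
)
# Roughly equivalent CLI: python -m transformers.onnx --model=allenai/longformer-base-4096 --feature=default onnx/
```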
@@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
 - LayoutLM
 - LayoutLMv3
 - LeViT
+- Longformer
 - LongT5
 - M2M100
 - Marian
......
@@ -160,6 +160,8 @@ class LEDEncoderSelfAttention(nn.Module):
         self.one_sided_attn_window_size = attention_window // 2
+        self.config = config
+
     def forward(
         self,
         hidden_states,
@@ -389,24 +391,48 @@ class LEDEncoderSelfAttention(nn.Module):
         return chunked_hidden_states

     @staticmethod
-    def _chunk(hidden_states, window_overlap):
+    def _chunk(hidden_states, window_overlap, onnx_export=False):
         """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
-        # non-overlapping chunks of size = 2w
-        hidden_states = hidden_states.view(
-            hidden_states.size(0),
-            hidden_states.size(1) // (window_overlap * 2),
-            window_overlap * 2,
-            hidden_states.size(2),
-        )
-
-        # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
-        chunk_size = list(hidden_states.size())
-        chunk_size[1] = chunk_size[1] * 2 - 1
-
-        chunk_stride = list(hidden_states.stride())
-        chunk_stride[1] = chunk_stride[1] // 2
-        return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+        if not onnx_export:
+            # non-overlapping chunks of size = 2w
+            hidden_states = hidden_states.view(
+                hidden_states.size(0),
+                torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
+                window_overlap * 2,
+                hidden_states.size(2),
+            )
+            # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
+            chunk_size = list(hidden_states.size())
+            chunk_size[1] = chunk_size[1] * 2 - 1
+            chunk_stride = list(hidden_states.stride())
+            chunk_stride[1] = chunk_stride[1] // 2
+            return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+
+        # When exporting to ONNX, use this separate logic
+        if hidden_states.size(1) == window_overlap * 2:
+            # simplest case
+            return hidden_states.unsqueeze(1)
+        else:
+            # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+
+            # TODO replace this with
+            # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+            # once `unfold` is supported
+            chunk_size = [
+                hidden_states.size(0),
+                hidden_states.size(1) // window_overlap - 1,
+                window_overlap * 2,
+                hidden_states.size(2),
+            ]
+            overlapping_chunks = torch.empty(chunk_size)
+            for chunk in range(chunk_size[1]):
+                overlapping_chunks[:, chunk, :, :] = hidden_states[
+                    :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
+                ]
+            return overlapping_chunks

     @staticmethod
     def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
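The ONNX branch above trades the `as_strided` trick for an explicit loop. A small standalone check, not part of the diff, that the two paths produce the same overlapping chunks on a toy tensor (`chunk_fast` and `chunk_slow` are hypothetical helper names):

```python
import torch


def chunk_fast(hidden_states, window_overlap):
    # non-overlapping chunks of size 2w, then overlap them via as_strided (the eager fast path)
    hidden_states = hidden_states.view(
        hidden_states.size(0),
        hidden_states.size(1) // (window_overlap * 2),
        window_overlap * 2,
        hidden_states.size(2),
    )
    chunk_size = list(hidden_states.size())
    chunk_size[1] = chunk_size[1] * 2 - 1
    chunk_stride = list(hidden_states.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)


def chunk_slow(hidden_states, window_overlap):
    # loop-based version mirroring the ONNX branch (no as_strided, unfold, or 2d indexing)
    chunk_size = [
        hidden_states.size(0),
        hidden_states.size(1) // window_overlap - 1,
        window_overlap * 2,
        hidden_states.size(2),
    ]
    overlapping_chunks = torch.empty(chunk_size)
    for chunk in range(chunk_size[1]):
        overlapping_chunks[:, chunk, :, :] = hidden_states[
            :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
        ]
    return overlapping_chunks


x = torch.randn(2, 8, 4)  # batch=2, seq_len=8, dim=4; window_overlap=2
assert torch.allclose(chunk_fast(x, 2), chunk_slow(x, 2))
```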
@@ -415,10 +441,14 @@ class LEDEncoderSelfAttention(nn.Module):
         ending_mask = beginning_mask.flip(dims=(1, 3))
         beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
         beginning_mask = beginning_mask.expand(beginning_input.size())
-        beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
+        input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
+            beginning_input, -float("inf")
+        ).where(beginning_mask.bool(), beginning_input)

         ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
         ending_mask = ending_mask.expand(ending_input.size())
-        ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
+        input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
+            ending_input, -float("inf")
+        ).where(ending_mask.bool(), ending_input)

     def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
         """
@@ -432,14 +462,14 @@ class LEDEncoderSelfAttention(nn.Module):
         ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
         assert query.size() == key.size()

-        chunks_count = seq_len // window_overlap - 1
+        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

         # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
         query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
         key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

-        query = self._chunk(query, window_overlap)
-        key = self._chunk(key, window_overlap)
+        query = self._chunk(query, window_overlap, self.config.__dict__.get("onnx_export", False))
+        key = self._chunk(key, window_overlap, self.config.__dict__.get("onnx_export", False))

         # matrix multiplication
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
@@ -457,7 +487,7 @@ class LEDEncoderSelfAttention(nn.Module):
         # window_overlap previous words). The following column is attention score from each word to itself, then
         # followed by window_overlap columns for the upper triangle.

-        diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
+        diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
             (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
         )
@@ -498,11 +528,14 @@ class LEDEncoderSelfAttention(nn.Module):
         assert seq_len % (window_overlap * 2) == 0
         assert attn_probs.size()[:3] == value.size()[:3]
         assert attn_probs.size(3) == 2 * window_overlap + 1

-        chunks_count = seq_len // window_overlap - 1
+        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

         # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
         chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
-            batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
+            batch_size * num_heads,
+            torch.div(seq_len, window_overlap, rounding_mode="trunc"),
+            window_overlap,
+            2 * window_overlap + 1,
         )

         # group batch_size and num_heads dimensions into one
@@ -577,9 +610,12 @@ class LEDEncoderSelfAttention(nn.Module):
         # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
         attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))

+        # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+        attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
         attn_probs_from_global_key[
-            is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
+            is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
         ] = torch.finfo(attn_probs_from_global_key.dtype).min
+        attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)

         return attn_probs_from_global_key
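The transpose/index/transpose-back pattern above exists because, per the linked ONNX docs, indexed writes only export when the advanced indices are consecutive leading dimensions. A toy illustration, not part of the diff, with hypothetical shapes and names:

```python
import torch

t = torch.zeros(2, 3, 4, 5)
batch_idx = torch.tensor([0, 1])
global_idx = torch.tensor([2, 4])

# non-consecutive advanced indices (dims 0 and 3): works eagerly, but not exportable
a = t.clone()
a[batch_idx, :, :, global_idx] = 1.0

# make the indexed dims consecutive via transpose, write, then transpose back
b = t.clone().transpose(1, 3)
b[batch_idx, global_idx, :, :] = 1.0
b = b.transpose(1, 3)

assert torch.equal(a, b)
```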
@@ -673,9 +709,12 @@ class LEDEncoderSelfAttention(nn.Module):
         global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

+        # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+        global_attn_scores = global_attn_scores.transpose(1, 2)
         global_attn_scores[
-            is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
+            is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
         ] = torch.finfo(global_attn_scores.dtype).min
+        global_attn_scores = global_attn_scores.transpose(1, 2)

         global_attn_scores = global_attn_scores.masked_fill(
             is_index_masked[:, None, None, :],
......
@@ -13,12 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Longformer configuration"""
-from typing import List, Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union

-from ...utils import logging
+from ...onnx import OnnxConfig
+from ...utils import TensorType, logging
 from ..roberta.configuration_roberta import RobertaConfig


+if TYPE_CHECKING:
+    from ...configuration_utils import PretrainedConfig
+    from ...onnx.config import PatchingSpec
+    from ...tokenization_utils_base import PreTrainedTokenizerBase
+
+
 logger = logging.get_logger(__name__)

 LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -71,6 +79,69 @@ class LongformerConfig(RobertaConfig):
     ```"""
     model_type = "longformer"

-    def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
+    def __init__(
+        self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, onnx_export: bool = False, **kwargs
+    ):
         super().__init__(sep_token_id=sep_token_id, **kwargs)
         self.attention_window = attention_window
+        self.onnx_export = onnx_export
+
+
+class LongformerOnnxConfig(OnnxConfig):
+    def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: "List[PatchingSpec]" = None):
+        super().__init__(config, task, patching_specs)
+        config.onnx_export = True
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("global_attention_mask", dynamic_axis),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        outputs = super().outputs
+        if self.task == "default":
+            outputs["pooler_output"] = {0: "batch"}
+        return outputs
+
+    @property
+    def atol_for_validation(self) -> float:
+        """
+        What absolute tolerance value to use during model conversion validation.
+
+        Returns:
+            Float absolute tolerance value.
+        """
+        return 1e-4
+
+    @property
+    def default_onnx_opset(self) -> int:
+        # needs to be >= 14 to support tril operator
+        return max(super().default_onnx_opset, 14)
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: "PreTrainedTokenizerBase",
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        inputs = super().generate_dummy_inputs(
+            preprocessor=tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )
+        import torch
+
+        inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"])
+        # make every second token global
+        inputs["global_attention_mask"][:, ::2] = 1
+
+        return inputs
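The new config can be exercised on its own to inspect the declared graph interface and the dummy inputs used for validation. A sketch, not part of the diff, assuming a checkpoint name and that the class is exported from the longformer module:

```python
from transformers import AutoTokenizer, LongformerConfig, TensorType
from transformers.models.longformer import LongformerOnnxConfig

checkpoint = "allenai/longformer-base-4096"
config = LongformerConfig.from_pretrained(checkpoint)
onnx_config = LongformerOnnxConfig(config, task="default")

print(onnx_config.inputs)               # input_ids, attention_mask, global_attention_mask with dynamic axes
print(onnx_config.outputs)              # includes pooler_output for the default task
print(onnx_config.default_onnx_opset)   # >= 14, needed for the tril operator
print(onnx_config.atol_for_validation)  # 1e-4

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(dummy["global_attention_mask"][0, :8])  # every second token marked global: 1, 0, 1, 0, ...
```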
@@ -532,6 +532,8 @@ class LongformerSelfAttention(nn.Module):
         self.one_sided_attn_window_size = attention_window // 2
+        self.config = config
+
     def forward(
         self,
         hidden_states,
@@ -761,24 +763,48 @@ class LongformerSelfAttention(nn.Module):
         return chunked_hidden_states

     @staticmethod
-    def _chunk(hidden_states, window_overlap):
+    def _chunk(hidden_states, window_overlap, onnx_export=False):
         """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
-        # non-overlapping chunks of size = 2w
-        hidden_states = hidden_states.view(
-            hidden_states.size(0),
-            hidden_states.size(1) // (window_overlap * 2),
-            window_overlap * 2,
-            hidden_states.size(2),
-        )
-
-        # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
-        chunk_size = list(hidden_states.size())
-        chunk_size[1] = chunk_size[1] * 2 - 1
-
-        chunk_stride = list(hidden_states.stride())
-        chunk_stride[1] = chunk_stride[1] // 2
-        return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+        if not onnx_export:
+            # non-overlapping chunks of size = 2w
+            hidden_states = hidden_states.view(
+                hidden_states.size(0),
+                torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
+                window_overlap * 2,
+                hidden_states.size(2),
+            )
+            # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
+            chunk_size = list(hidden_states.size())
+            chunk_size[1] = chunk_size[1] * 2 - 1
+            chunk_stride = list(hidden_states.stride())
+            chunk_stride[1] = chunk_stride[1] // 2
+            return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+
+        # When exporting to ONNX, use this separate logic
+        if hidden_states.size(1) == window_overlap * 2:
+            # simplest case
+            return hidden_states.unsqueeze(1)
+        else:
+            # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+
+            # TODO replace this with
+            # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+            # once `unfold` is supported
+            chunk_size = [
+                hidden_states.size(0),
+                hidden_states.size(1) // window_overlap - 1,
+                window_overlap * 2,
+                hidden_states.size(2),
+            ]
+            overlapping_chunks = torch.empty(chunk_size)
+            for chunk in range(chunk_size[1]):
+                overlapping_chunks[:, chunk, :, :] = hidden_states[
+                    :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
+                ]
+            return overlapping_chunks

     @staticmethod
     def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
@@ -787,10 +813,14 @@ class LongformerSelfAttention(nn.Module):
         ending_mask = beginning_mask.flip(dims=(1, 3))
         beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
         beginning_mask = beginning_mask.expand(beginning_input.size())
-        beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
+        input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
+            beginning_input, -float("inf")
+        ).where(beginning_mask.bool(), beginning_input)

         ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
         ending_mask = ending_mask.expand(ending_input.size())
-        ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
+        input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
+            ending_input, -float("inf")
+        ).where(ending_mask.bool(), ending_input)

     def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
         """
@@ -804,14 +834,14 @@ class LongformerSelfAttention(nn.Module):
         ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
         assert query.size() == key.size()

-        chunks_count = seq_len // window_overlap - 1
+        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

         # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
         query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
         key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

-        query = self._chunk(query, window_overlap)
-        key = self._chunk(key, window_overlap)
+        query = self._chunk(query, window_overlap, self.config.__dict__.get("onnx_export", False))
+        key = self._chunk(key, window_overlap, self.config.__dict__.get("onnx_export", False))

         # matrix multiplication
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
@@ -829,7 +859,7 @@ class LongformerSelfAttention(nn.Module):
         # window_overlap previous words). The following column is attention score from each word to itself, then
         # followed by window_overlap columns for the upper triangle.

-        diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
+        diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
             (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
         )
@@ -870,11 +900,14 @@ class LongformerSelfAttention(nn.Module):
         assert seq_len % (window_overlap * 2) == 0
         assert attn_probs.size()[:3] == value.size()[:3]
         assert attn_probs.size(3) == 2 * window_overlap + 1

-        chunks_count = seq_len // window_overlap - 1
+        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

         # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
         chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
-            batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
+            batch_size * num_heads,
+            torch.div(seq_len, window_overlap, rounding_mode="trunc"),
+            window_overlap,
+            2 * window_overlap + 1,
         )

         # group batch_size and num_heads dimensions into one
@@ -949,9 +982,12 @@ class LongformerSelfAttention(nn.Module):
         # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
         attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))

+        # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+        attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
         attn_probs_from_global_key[
-            is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
+            is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
         ] = torch.finfo(attn_probs_from_global_key.dtype).min
+        attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)

         return attn_probs_from_global_key
@@ -1045,9 +1081,12 @@ class LongformerSelfAttention(nn.Module):
         global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

+        # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+        global_attn_scores = global_attn_scores.transpose(1, 2)
         global_attn_scores[
-            is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
+            is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
         ] = torch.finfo(global_attn_scores.dtype).min
+        global_attn_scores = global_attn_scores.transpose(1, 2)

         global_attn_scores = global_attn_scores.masked_fill(
             is_index_masked[:, None, None, :],
@@ -1588,7 +1627,7 @@ class LongformerModel(LongformerPreTrainedModel):
                 inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)

             attention_mask = nn.functional.pad(
-                attention_mask, (0, padding_len), value=False
+                attention_mask, (0, padding_len), value=0
             )  # no attention on the padding tokens
             token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0)  # pad with token_type_id = 0
......
@@ -358,6 +358,15 @@ class FeaturesManager:
             "seq2seq-lm-with-past",
             onnx_config_cls="models.longt5.LongT5OnnxConfig",
         ),
+        "longformer": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "multiple-choice",
+            "question-answering",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls="models.longformer.LongformerOnnxConfig",
+        ),
         "marian": supported_features_mapping(
             "default",
             "default-with-past",
......
@@ -212,6 +212,7 @@ PYTORCH_EXPORT_MODELS = {
     ("data2vec-vision", "facebook/data2vec-vision-base"),
     ("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
     ("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
+    ("longformer", "allenai/longformer-base-4096"),
     ("yolos", "hustvl/yolos-tiny"),
 }
......