Unverified commit 3d0c0ae4 authored by fxmarty, committed by GitHub

Fix longformer onnx broken export (#20292)

* fix control flow for onnx export

* fix warning

* fix the case padding_len = 0, make the recorded control flows explicit

* style

* style

* fix bug

* fix copy

* nits
parent 9ef46659
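For context on the fix: ONNX export of PyTorch models works by tracing, i.e. the model runs once on dummy inputs and only the operations that actually execute get recorded in the graph. Data-dependent Python `if`/`else` therefore bakes a single branch into the exported model, which is the failure mode this commit removes. A minimal illustration of the problem (not part of this commit; the module and file names are made up):

# Sketch: data-dependent Python control flow under a traced ONNX export.
# Only the branch taken for the example input is recorded in the graph.
import torch


class Branchy(torch.nn.Module):
    def forward(self, x):
        # evaluated once at trace time, then frozen into the exported graph
        if x.size(1) == 4:
            return x.unsqueeze(1)
        return x * 2


# The export records only the `x.size(1) == 4` branch; feeding a different
# sequence length to the ONNX model later would silently take the wrong path.
torch.onnx.export(
    Branchy(),
    torch.randn(1, 4, 8),
    "branchy.onnx",
    input_names=["x"],
    dynamic_axes={"x": {1: "seq_len"}},
)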
@@ -392,7 +392,7 @@ class LEDEncoderSelfAttention(nn.Module):
         return chunked_hidden_states
 
     @staticmethod
-    def _chunk(hidden_states, window_overlap, onnx_export=False):
+    def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
         """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
         if not onnx_export:
             # non-overlapping chunks of size = 2w
@@ -411,29 +411,26 @@ class LEDEncoderSelfAttention(nn.Module):
             return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
 
         # When exporting to ONNX, use this separate logic
-        if hidden_states.size(1) == window_overlap * 2:
-            # simplest case
-            return hidden_states.unsqueeze(1)
-        else:
-            # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
-            # TODO replace this with
-            # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
-            # once `unfold` is supported
-            chunk_size = [
-                hidden_states.size(0),
-                hidden_states.size(1) // window_overlap - 1,
-                window_overlap * 2,
-                hidden_states.size(2),
-            ]
-            overlapping_chunks = torch.empty(chunk_size)
-            for chunk in range(chunk_size[1]):
-                overlapping_chunks[:, chunk, :, :] = hidden_states[
-                    :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
-                ]
-            return overlapping_chunks
+        # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+        # TODO replace this with
+        # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+        # once `unfold` is supported
+        # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
+        chunk_size = [
+            hidden_states.size(0),
+            torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
+            window_overlap * 2,
+            hidden_states.size(2),
+        ]
+        overlapping_chunks = torch.empty(chunk_size)
+        for chunk in range(chunk_size[1]):
+            overlapping_chunks[:, chunk, :, :] = hidden_states[
+                :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
+            ]
+        return overlapping_chunks
 
     @staticmethod
     def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
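The branch-free slow path above can be read in isolation. A standalone sketch with toy sizes (eager execution; the helper name is made up) — note that `torch.div(..., rounding_mode="trunc")` keeps the chunk count as a traceable tensor op rather than Python integer arithmetic:

# Sketch of the export-friendly overlapping chunking, with toy sizes.
import torch

def chunk_for_onnx(hidden_states: torch.Tensor, window_overlap: int) -> torch.Tensor:
    chunk_size = [
        hidden_states.size(0),
        torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
        window_overlap * 2,
        hidden_states.size(2),
    ]
    overlapping_chunks = torch.empty(chunk_size)
    for chunk in range(chunk_size[1]):
        # each chunk covers 2 * window_overlap positions, shifted by window_overlap
        overlapping_chunks[:, chunk, :, :] = hidden_states[
            :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
        ]
    return overlapping_chunks

x = torch.randn(1, 8, 3)           # (batch, seq_len, hidden), seq_len = 4 * window_overlap
print(chunk_for_onnx(x, 2).shape)  # torch.Size([1, 3, 4, 3]): 3 overlapping chunks of size 4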
@@ -221,7 +221,10 @@ class LongformerOnnxConfig(OnnxConfig):
         )
         import torch
 
+        # for some reason, replacing this code by inputs["global_attention_mask"] = torch.randint(2, inputs["input_ids"].shape, dtype=torch.int64)
+        # makes the export fail randomly
         inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"])
         # make every second token global
         inputs["global_attention_mask"][:, ::2] = 1
+
         return inputs
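The resulting dummy `global_attention_mask` alternates global and local tokens, so both attention code paths execute during tracing. A small sketch of the pattern (the shapes are illustrative):

# Sketch of the dummy global_attention_mask built above: zeros everywhere,
# then every second token marked global.
import torch

input_ids = torch.ones(2, 8, dtype=torch.int64)  # placeholder dummy input
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, ::2] = 1                # every second token is global
print(global_attention_mask[0])                  # tensor([1, 0, 1, 0, 1, 0, 1, 0])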
@@ -763,7 +763,7 @@ class LongformerSelfAttention(nn.Module):
         return chunked_hidden_states
 
     @staticmethod
-    def _chunk(hidden_states, window_overlap, onnx_export=False):
+    def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
         """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
         if not onnx_export:
             # non-overlapping chunks of size = 2w
@@ -782,29 +782,26 @@ class LongformerSelfAttention(nn.Module):
             return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
 
         # When exporting to ONNX, use this separate logic
-        if hidden_states.size(1) == window_overlap * 2:
-            # simplest case
-            return hidden_states.unsqueeze(1)
-        else:
-            # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
-            # TODO replace this with
-            # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
-            # once `unfold` is supported
-            chunk_size = [
-                hidden_states.size(0),
-                hidden_states.size(1) // window_overlap - 1,
-                window_overlap * 2,
-                hidden_states.size(2),
-            ]
-            overlapping_chunks = torch.empty(chunk_size)
-            for chunk in range(chunk_size[1]):
-                overlapping_chunks[:, chunk, :, :] = hidden_states[
-                    :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
-                ]
-            return overlapping_chunks
+        # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+        # TODO replace this with
+        # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+        # once `unfold` is supported
+        # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
+        chunk_size = [
+            hidden_states.size(0),
+            torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
+            window_overlap * 2,
+            hidden_states.size(2),
+        ]
+        overlapping_chunks = torch.empty(chunk_size)
+        for chunk in range(chunk_size[1]):
+            overlapping_chunks[:, chunk, :, :] = hidden_states[
+                :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
+            ]
+        return overlapping_chunks
 
     @staticmethod
     def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
@@ -1291,9 +1288,10 @@ class LongformerEncoder(nn.Module):
         output_hidden_states=False,
         return_dict=True,
     ):
         is_index_masked = attention_mask < 0
         is_index_global_attn = attention_mask > 0
+        # Record `is_global_attn == True` to enable ONNX export
         is_global_attn = is_index_global_attn.flatten().any().item()
 
         all_hidden_states = () if output_hidden_states else None
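Here `.item()` converts the one-element tensor to a plain Python bool, so whatever value the dummy inputs produce at export time is frozen into the graph; since the dummy inputs above always mark some tokens as global, the recorded path is the global-attention one. A sketch of that behavior:

# Sketch: `.item()` yields a Python bool, so at trace time the value computed
# from the dummy inputs is baked into the exported graph as a constant.
import torch

attention_mask = torch.tensor([[0, 0, 1, 0]])  # mirrors the encoder: 0 = local, >0 = global, <0 = masked
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
print(type(is_global_attn), is_global_attn)   # <class 'bool'> True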
@@ -1349,15 +1347,14 @@ class LongformerEncoder(nn.Module):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        # undo padding
-        if padding_len > 0:
-            # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
-            hidden_states = hidden_states[:, :-padding_len]
-            if output_hidden_states:
-                all_hidden_states = tuple([state[:, :-padding_len] for state in all_hidden_states])
-
-            if output_attentions:
-                all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+        # undo padding if necessary
+        # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+        hidden_states = hidden_states[:, : hidden_states.shape[1] - padding_len]
+        if output_hidden_states:
+            all_hidden_states = tuple([state[:, : state.shape[1] - padding_len] for state in all_hidden_states])
+
+        if output_attentions:
+            all_attentions = tuple([state[:, :, : state.shape[2] - padding_len, :] for state in all_attentions])
 
         if not return_dict:
             return tuple(
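The rewritten slicing is what makes the unpadding safe to run unconditionally: `x[:, :-0]` is `x[:, :0]` and would slice everything away, while the shape-based form is a no-op when `padding_len == 0` and identical to the old form otherwise. A quick check:

# Sketch of why the shape-based slicing also covers padding_len == 0.
import torch

x = torch.randn(1, 6, 4)

# old form at padding_len == 0: x[:, :-0] == x[:, :0], everything sliced away
print(x[:, :-0].shape)                          # torch.Size([1, 0, 4])

# new form: a no-op at 0, same result as the old form otherwise
padding_len = 0
print(x[:, : x.shape[1] - padding_len].shape)   # torch.Size([1, 6, 4])
padding_len = 2
print(x[:, : x.shape[1] - padding_len].shape)   # torch.Size([1, 4, 4])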
@@ -1612,6 +1609,8 @@ class LongformerModel(LongformerPreTrainedModel):
         batch_size, seq_len = input_shape[:2]
         padding_len = (attention_window - seq_len % attention_window) % attention_window
+
+        # this path should be recorded in the ONNX export, it is fine with padding_len == 0 as well
         if padding_len > 0:
             logger.info(
                 f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
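The padding arithmetic rounds `seq_len` up to the next multiple of `attention_window` and yields 0 when the length is already a multiple, which is why the now-unconditional unpadding above stays correct. For instance:

# Sketch of the padding arithmetic with a typical attention window.
attention_window = 512
for seq_len in (511, 512, 513, 1024):
    padding_len = (attention_window - seq_len % attention_window) % attention_window
    print(seq_len, "->", padding_len)  # 511 -> 1, 512 -> 0, 513 -> 511, 1024 -> 0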