Unverified Commit 3d0c0ae4 authored by fxmarty, committed by GitHub

Fix longformer onnx broken export (#20292)

* fix control flow for ONNX export

* fix warning

* fix the case padding_len = 0; make the recorded control flow explicit

* style

* style

* fix bug

* fix copy

* nits
parent 9ef46659
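Context for the changes below: torch.onnx.export traces the model with example inputs, so data-dependent Python if/else is resolved once at trace time and only the taken branch is recorded into the graph. A minimal sketch of that failure mode (the module and file name here are my illustration, not from this PR):

    import torch

    class Chunker(torch.nn.Module):
        def forward(self, hidden_states):
            # this comparison is plain Python at trace time: only one branch is recorded
            if hidden_states.size(1) == 4:
                return hidden_states.unsqueeze(1)
            return hidden_states * 2

    # the exported graph hard-codes the branch taken for this example input;
    # inputs with another sequence length would silently get the wrong computation
    torch.onnx.export(Chunker(), torch.randn(1, 4, 8), "chunker.onnx")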
@@ -392,7 +392,7 @@ class LEDEncoderSelfAttention(nn.Module):
         return chunked_hidden_states

     @staticmethod
-    def _chunk(hidden_states, window_overlap, onnx_export=False):
+    def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
         """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
         if not onnx_export:
             # non-overlapping chunks of size = 2w
@@ -411,19 +411,16 @@ class LEDEncoderSelfAttention(nn.Module):
             return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)

         # When exporting to ONNX, use this separate logic
-        if hidden_states.size(1) == window_overlap * 2:
-            # simplest case
-            return hidden_states.unsqueeze(1)
-        else:
-            # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
-            # TODO replace this with
-            # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
-            # once `unfold` is supported
-            chunk_size = [
-                hidden_states.size(0),
-                hidden_states.size(1) // window_overlap - 1,
-                window_overlap * 2,
-                hidden_states.size(2),
-            ]
+        # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+        # TODO replace this with
+        # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+        # once `unfold` is supported
+        # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
+        chunk_size = [
+            hidden_states.size(0),
+            torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
+            window_overlap * 2,
+            hidden_states.size(2),
+        ]
...
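Why `//` is swapped for torch.div with rounding_mode="trunc": on traced tensor values, the `//` operator lowered to floor_divide, which at the time of this PR emitted a deprecation warning during export (plausibly the "fix warning" item in the commit message). A quick sanity check that the two agree for the non-negative sizes used here (my illustration, not part of the diff):

    import torch

    seq_len, window_overlap = torch.tensor(1024), torch.tensor(256)
    old = seq_len // window_overlap - 1                                  # floor division
    new = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1  # truncating division
    assert int(old) == int(new) == 3  # identical for the positive sizes seen in practice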
@@ -221,7 +221,10 @@ class LongformerOnnxConfig(OnnxConfig):
         )
         import torch

+        # for some reason, replacing this code with inputs["global_attention_mask"] = torch.randint(2, inputs["input_ids"].shape, dtype=torch.int64)
+        # makes the export fail randomly
+
         inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"])
         # make every second token global
         inputs["global_attention_mask"][:, ::2] = 1

         return inputs
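For illustration, the dummy global_attention_mask this config generates, assuming a hypothetical (1, 8) input_ids tensor (the shape is my example, not from the diff):

    import torch

    input_ids = torch.ones(1, 8, dtype=torch.int64)  # hypothetical dummy input
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, ::2] = 1  # make every second token global
    print(global_attention_mask)  # tensor([[1, 0, 1, 0, 1, 0, 1, 0]])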
@@ -763,7 +763,7 @@ class LongformerSelfAttention(nn.Module):
         return chunked_hidden_states

     @staticmethod
-    def _chunk(hidden_states, window_overlap, onnx_export=False):
+    def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
         """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
         if not onnx_export:
             # non-overlapping chunks of size = 2w
@@ -782,19 +782,16 @@ class LongformerSelfAttention(nn.Module):
             return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)

         # When exporting to ONNX, use this separate logic
-        if hidden_states.size(1) == window_overlap * 2:
-            # simplest case
-            return hidden_states.unsqueeze(1)
-        else:
-            # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
-            # TODO replace this with
-            # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
-            # once `unfold` is supported
-            chunk_size = [
-                hidden_states.size(0),
-                hidden_states.size(1) // window_overlap - 1,
-                window_overlap * 2,
-                hidden_states.size(2),
-            ]
+        # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+        # TODO replace this with
+        # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+        # once `unfold` is supported
+        # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
+        chunk_size = [
+            hidden_states.size(0),
+            torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
+            window_overlap * 2,
+            hidden_states.size(2),
+        ]
@@ -1291,9 +1288,10 @@ class LongformerEncoder(nn.Module):
         output_hidden_states=False,
         return_dict=True,
     ):

         is_index_masked = attention_mask < 0
         is_index_global_attn = attention_mask > 0
+        # Record `is_global_attn == True` to enable ONNX export
         is_global_attn = is_index_global_attn.flatten().any().item()

         all_hidden_states = () if output_hidden_states else None
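The .item() call is why the new comment matters: it converts the tensor to a plain Python bool, so every downstream `if is_global_attn:` is decided at trace time, and the dummy inputs above enable global attention precisely so that the `is_global_attn == True` path is the one recorded. A small illustration (values are my example):

    import torch

    # in the merged mask, negative values are padding and positive values are global tokens
    attention_mask = torch.tensor([[0.0, 1.0, 0.0, -10000.0]])
    is_index_global_attn = attention_mask > 0
    is_global_attn = is_index_global_attn.flatten().any().item()
    print(type(is_global_attn), is_global_attn)  # <class 'bool'> True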
@@ -1349,15 +1347,14 @@ class LongformerEncoder(nn.Module):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        # undo padding
-        if padding_len > 0:
-            # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
-            hidden_states = hidden_states[:, :-padding_len]
-            if output_hidden_states:
-                all_hidden_states = tuple([state[:, :-padding_len] for state in all_hidden_states])
-            if output_attentions:
-                all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+        # undo padding if necessary
+        # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+        hidden_states = hidden_states[:, : hidden_states.shape[1] - padding_len]
+        if output_hidden_states:
+            all_hidden_states = tuple([state[:, : state.shape[1] - padding_len] for state in all_hidden_states])
+        if output_attentions:
+            all_attentions = tuple([state[:, :, : state.shape[2] - padding_len, :] for state in all_attentions])

         if not return_dict:
             return tuple(
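The slicing change is what makes the unconditional unpadding safe: with padding_len == 0, the old `[:, :-padding_len]` form degenerates to `[:, :0]` and returns an empty tensor, while the explicit `shape[1] - padding_len` bound is a no-op, so the `if padding_len > 0` branch can be dropped from the trace. A minimal repro (mine, not from the PR):

    import torch

    x = torch.randn(2, 6)
    padding_len = 0
    print(x[:, :-padding_len].shape)               # torch.Size([2, 0]) -- silently drops everything
    print(x[:, : x.shape[1] - padding_len].shape)  # torch.Size([2, 6]) -- unchanged, as intended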
@@ -1612,6 +1609,8 @@ class LongformerModel(LongformerPreTrainedModel):
         batch_size, seq_len = input_shape[:2]
         padding_len = (attention_window - seq_len % attention_window) % attention_window

+        # this path should be recorded in the ONNX export, it is fine with padding_len == 0 as well
+
         if padding_len > 0:
             logger.info(
                 f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
...
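For reference, the padding arithmetic the new comment refers to, with an assumed attention_window of 512 (numbers are my example):

    attention_window = 512
    for seq_len in (1000, 1024):
        padding_len = (attention_window - seq_len % attention_window) % attention_window
        print(seq_len, "->", padding_len)  # 1000 -> 24, 1024 -> 0, so padding_len == 0 needs no special branch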