Unverified Commit c9f191a0 authored by Merve Noyan, committed by GitHub

Fix ONNX exports for Optimum compatible models (#31311)



* fixed models

* format with bumped ruff version on my local

* fix copies

* add tracing checks

* format

* Update src/transformers/utils/generic.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* format

* style fix

* Update modeling_mobilevit.py

* add docstring and change name

* Update __init__.py

* Update __init__.py

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent dc76e9fa
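
Every hunk below applies the same pattern: shape arithmetic that used to produce bare Python ints or floats is routed through tracing-aware casts, so that torch.jit.trace (which the Optimum ONNX exporter relies on) records the computation symbolically instead of baking trace-time constants into the graph. As a minimal sketch of the failure mode being fixed (the Resizer module is hypothetical, not part of this diff):

    import torch

    class Resizer(torch.nn.Module):
        def forward(self, x):
            # int() pulls the traced shape value down to a Python int, so
            # torch.jit.trace freezes it as a constant in the graph.
            new_height = int(x.shape[2] * 2)
            return torch.nn.functional.interpolate(x, size=(new_height, int(x.shape[3])))

    traced = torch.jit.trace(Resizer(), torch.randn(1, 3, 8, 8))
    # A taller input still gets resized to the constant recorded at trace time
    # (height 16, not 2 * 12 = 24), which is exactly what breaks dynamic-axis
    # ONNX export.
    print(traced(torch.randn(1, 3, 12, 8)).shape)  # torch.Size([1, 3, 16, 8])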
src/transformers/models/clap/modeling_clap.py
@@ -37,6 +37,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
@@ -590,8 +591,10 @@ class ClapAudioLayer(nn.Module):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
...
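
The guarded assignment above recurs verbatim in the Donut-Swin and Swin hunks further down: min() over a tuple returns a plain Python value even while tracing, so the tracing branch rebuilds the minimum as a tensor to keep it in the graph. Restated as a standalone sketch (pick_window_size is a hypothetical helper name):

    import torch

    def pick_window_size(input_resolution):
        # Under tracing, keep the minimum as a tensor so the exported graph
        # recomputes it from the actual input; eagerly, return a Python int.
        if torch.jit.is_tracing():
            return torch.min(torch.tensor(input_resolution))
        return min(input_resolution)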
src/transformers/models/donut/modeling_donut_swin.py
@@ -35,6 +35,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    torch_int,
 )
 from .configuration_donut_swin import DonutSwinConfig
@@ -562,8 +563,10 @@ class DonutSwinLayer(nn.Module):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
...
src/transformers/models/dpt/modeling_dpt.py
@@ -39,7 +39,7 @@ from ...file_utils import (
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import ModelOutput, logging
+from ...utils import ModelOutput, logging, torch_int
 from ...utils.backbone_utils import load_backbone
 from .configuration_dpt import DPTConfig
@@ -226,7 +226,7 @@ class DPTViTEmbeddings(nn.Module):
         posemb_tok = posemb[:, :start_index]
         posemb_grid = posemb[0, start_index:]

-        old_grid_size = int(math.sqrt(len(posemb_grid)))
+        old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)

         posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
         posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
...
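
Note that len(posemb_grid) also had to go: len() always returns a Python int, while posemb_grid.size(0) stays a traced tensor under torch.jit.trace, which is what allows torch_int to preserve it. Outside tracing the two spellings agree:

    import torch

    posemb_grid = torch.randn(196, 768)  # e.g. a 14x14 grid of position embeddings
    assert int(posemb_grid.size(0) ** 0.5) == int(len(posemb_grid) ** 0.5) == 14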
src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -33,7 +33,13 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_float,
+)
 from .configuration_imagegpt import ImageGPTConfig
@@ -229,7 +235,7 @@ class ImageGPTAttention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

         if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+            attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
...
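
Here the attention scale itself becomes tracing-aware: value.size(-1) is a traced tensor inside torch.jit.trace, so torch_float keeps the 1/sqrt(d) scaling symbolic instead of freezing the head size. The same scaling in isolation (ScaledDot is a hypothetical module, not from this diff):

    import torch
    from transformers.utils import torch_float  # helper added by this PR

    class ScaledDot(torch.nn.Module):
        def forward(self, query, key, value):
            attn = torch.matmul(query, key.transpose(-1, -2))
            # A Python float eagerly; a float32 tensor under tracing.
            return attn / torch_float(value.size(-1) ** 0.5)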
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -33,7 +33,13 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
 from .configuration_layoutlmv3 import LayoutLMv3Config
@@ -910,8 +916,8 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
         patch_height = patch_width = None
         if pixel_values is not None:
             patch_height, patch_width = (
-                int(pixel_values.shape[2] / self.config.patch_size),
-                int(pixel_values.shape[3] / self.config.patch_size),
+                torch_int(pixel_values.shape[2] / self.config.patch_size),
+                torch_int(pixel_values.shape[3] / self.config.patch_size),
             )
             visual_embeddings = self.forward_image(pixel_values)
             visual_attention_mask = torch.ones(
...
src/transformers/models/mobilevit/modeling_mobilevit.py
@@ -39,6 +39,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_mobilevit import MobileViTConfig
@@ -437,8 +438,16 @@ class MobileViTLayer(nn.Module):
         batch_size, channels, orig_height, orig_width = features.shape

-        new_height = int(math.ceil(orig_height / patch_height) * patch_height)
-        new_width = int(math.ceil(orig_width / patch_width) * patch_width)
+        new_height = (
+            torch_int(torch.ceil(orig_height / patch_height) * patch_height)
+            if torch.jit.is_tracing()
+            else int(math.ceil(orig_height / patch_height) * patch_height)
+        )
+        new_width = (
+            torch_int(torch.ceil(orig_width / patch_width) * patch_width)
+            if torch.jit.is_tracing()
+            else int(math.ceil(orig_width / patch_width) * patch_width)
+        )

         interpolate = False
         if new_width != orig_width or new_height != orig_height:
...
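
Both MobileViT expressions round the feature map up to the nearest multiple of the patch size before unfolding; only the tracing branch changes. A worked eager check of the arithmetic:

    import math

    patch_height, orig_height = 2, 15
    # ceil(15 / 2) * 2 == 16, the next multiple of 2 at or above 15
    assert int(math.ceil(orig_height / patch_height) * patch_height) == 16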
src/transformers/models/sam/modeling_sam.py
@@ -15,7 +15,6 @@
 """PyTorch SAM model."""

 import collections
-import math
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
@@ -232,7 +231,7 @@ class SamAttention(nn.Module):
         # SamAttention
         _, _, _, c_per_head = query.shape
         attn = query @ key.permute(0, 1, 3, 2)  # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
-        attn = attn / math.sqrt(c_per_head)
+        attn = attn / (c_per_head**0.5)
         attn = torch.softmax(attn, dim=-1)

         if attention_similarity is not None:
...
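
Dropping math.sqrt works because **0.5 is defined for both Python numbers and traced shape tensors, whereas math.sqrt coerces its argument to a Python float and would freeze the head size at trace time. Outside tracing the two are equivalent:

    import math

    c_per_head = 64
    assert c_per_head**0.5 == math.sqrt(c_per_head) == 8.0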
src/transformers/models/swin/modeling_swin.py
@@ -36,6 +36,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_swin import SwinConfig
@@ -639,8 +640,10 @@ class SwinLayer(nn.Module):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
...
src/transformers/utils/__init__.py
@@ -60,6 +60,8 @@ from .generic import (
     tensor_size,
     to_numpy,
     to_py_obj,
+    torch_float,
+    torch_int,
     transpose,
     working_or_temp_dir,
 )
...
src/transformers/utils/generic.py
@@ -753,6 +753,30 @@ def infer_framework(model_class):
     raise TypeError(f"Could not infer framework from class {model_class}.")

+def torch_int(x):
+    """
+    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
+    """
+    if not is_torch_available():
+        return int(x)
+
+    import torch
+
+    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
+
+
+def torch_float(x):
+    """
+    Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
+    """
+    if not is_torch_available():
+        return float(x)
+
+    import torch
+
+    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)
+
+
 def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     """
     Decorator to filter out named arguments that are not in the function signature.
...
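
For reference, a quick sanity check of both behaviors of the new helpers (assuming a transformers build that exports them):

    import torch
    from transformers.utils import torch_int

    # Outside tracing: behaves like int().
    assert torch_int(7.9) == 7 and isinstance(torch_int(7.9), int)

    def head(x):
        # Inside torch.jit.trace, x.size(0) is a traced tensor, so the slice
        # bound stays symbolic instead of being frozen as a constant.
        return x[: torch_int(x.size(0) / 2)]

    traced = torch.jit.trace(head, torch.arange(10.0))
    # Expected to keep tracking the input length (first 3 of 6 elements)
    # rather than reuse the trace-time constant 5.
    print(traced(torch.arange(6.0)))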