Unverified Commit 8c7481f3 authored by Michael Benayoun, committed by GitHub

ViT and Swin symbolic tracing with torch.fx (#17182)

* Support tracing for ViT

* Swin support

* Fix copies

* Fix type annotation issue

* Removed unused import
parent 1a688709
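
This PR makes ViT and Swin (plus their copies in DeiT, DPT, MaskFormer, ViTMAE, and YOLOS) traceable with `torch.fx` through the HF tracer. A minimal usage sketch, not part of the diff: the model is randomly initialized from a config so nothing is downloaded, and the exact `symbolic_trace` signature and output type may vary slightly across transformers versions.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification
from transformers.utils.fx import symbolic_trace

# Random weights from the default config (224x224 images, 16x16 patches).
model = ViTForImageClassification(ViTConfig())
model.eval()

# Vision models are traced from `pixel_values` instead of `input_ids`.
traced = symbolic_trace(model, input_names=["pixel_values"])

pixel_values = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    eager_logits = model(pixel_values=pixel_values).logits
    traced_logits = traced(pixel_values=pixel_values)["logits"]
print(torch.allclose(eager_logits, traced_logits))  # True
```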
@@ -168,7 +168,7 @@ class DeiTSelfAttention(nn.Module):
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -200,7 +200,7 @@ class DeiTSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
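
The recurring `view(*new_x_shape)` → `view(new_x_shape)` rewrite, here and in every self-attention class below, is what makes these modules traceable: during symbolic tracing, `x.size()` typically yields a proxied value that cannot be star-unpacked into separate positional arguments, while passing the sequence as a single argument records one clean `view` call. Eager behavior is unchanged, since `Tensor.view` accepts either form. A quick equivalence check (illustrative, not part of the diff):

```python
import torch

x = torch.randn(2, 5, 12)
new_x_shape = x.size()[:-1] + (3, 4)  # torch.Size([2, 5, 3, 4])

# view() accepts either unpacked ints or a single shape sequence.
assert torch.equal(x.view(*new_x_shape), x.view(new_x_shape))
```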
@@ -177,7 +177,7 @@ class DPTViTSelfAttention(nn.Module):
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -209,7 +209,7 @@ class DPTViTSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width):
     """
     Merges windows to produce higher resolution features.
     """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
     windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
     return windows
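
`int(...)` becomes `math.floor(...)` here and in the masked-image-modeling heads below, presumably because calling `int()` on a traced value must produce a concrete Python integer and so breaks the trace, whereas `math.floor` can be handled by the tracer; for the positive values involved, both give the same result. The arithmetic, with illustrative numbers:

```python
import math

# A 16x16 feature map split into 4x4 windows gives 16 windows per image;
# 64 windows therefore correspond to a batch of 4 images.
num_windows, height, width, window_size = 64, 16, 16, 4
batch_size = math.floor(num_windows / (height * width / window_size / window_size))
print(batch_size)  # 4
```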
@@ -697,7 +697,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -750,7 +750,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
         context_layer = torch.matmul(attention_probs, value_layer)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
     """
     Merges windows to produce higher resolution features.
     """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
     windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
     return windows
@@ -435,7 +435,7 @@ class SwinSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -488,7 +488,7 @@ class SwinSelfAttention(nn.Module):
         context_layer = torch.matmul(attention_probs, value_layer)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -1071,7 +1071,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
         # Reshape to (batch_size, num_channels, height, width)
         sequence_output = sequence_output.transpose(1, 2)
         batch_size, num_channels, sequence_length = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
         sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

         # Reconstruct pixel values
...
@@ -213,7 +213,7 @@ class ViTSelfAttention(nn.Module):
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -245,7 +245,7 @@ class ViTSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -687,7 +687,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         # Reshape to (batch_size, num_channels, height, width)
         sequence_output = sequence_output[:, 1:]
         batch_size, sequence_length, num_channels = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
         sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

         # Reconstruct pixel values
...
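
The same substitution recovers the square token grid in the masked-image-modeling heads: for a 224x224 image with 16x16 patches, ViT yields 196 patch tokens (after dropping the CLS token), and `floor(sqrt(196)) = 14` gives the 14x14 spatial grid. With illustrative numbers:

```python
import math

image_size, patch_size = 224, 16
sequence_length = (image_size // patch_size) ** 2  # 196 patch tokens
height = width = math.floor(sequence_length**0.5)
print(height, width)  # 14 14
```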
@@ -342,7 +342,7 @@ class ViTMAESelfAttention(nn.Module):
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -374,7 +374,7 @@ class ViTMAESelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -280,7 +280,7 @@ class YolosSelfAttention(nn.Module):
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -312,7 +312,7 @@ class YolosSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -14,12 +14,12 @@
 # limitations under the License.
 import builtins
+import collections
 import functools
 import inspect
 import math
 import random
 import warnings
-from copy import deepcopy
 from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

 import torch
@@ -31,6 +31,7 @@ from .. import (
     CONFIG_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
     MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -71,6 +72,7 @@ def _generate_supported_model_classes(
         "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
         "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     }
@@ -100,6 +102,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "gpt_neo",
     "t5",
     "roberta",
+    "vit",
+    "swin",
     # TODO: add support for them as it should be quite easy to do so (small blocking issues).
     # "layoutlm",
     # "xlnet",
@@ -276,6 +280,31 @@ def torch_tensor_index_select(self, dim, index):
     return torch_tensor_index_select(self, dim, index)


+def torch_roll(input, shifts, dims=None):
+    return input
+
+
+def torch_nn_conv2d(self, input):
+    h_in, w_in = input.shape[-2:]
+    shape = None
+    padding = self.padding
+    if padding == "valid":
+        padding = (0, 0)
+    if padding == "same":
+        shape = list(input.shape)
+    if shape is None:
+        shape = list(input.shape)
+        h_out = math.floor(
+            (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+        )
+        w_out = math.floor(
+            (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+        )
+        shape[-2:] = [h_out, w_out]
+    shape[-3] = self.out_channels
+    return torch.empty(shape, device="meta")
+
+
 def torch_nn_mseloss(self, input, target):
     if self.reduction == "none":
         shape = target.shape
@@ -317,9 +346,11 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.Tensor.mul: torch_tensor_mul_override,
     torch.matmul: torch_matmul_override,
     torch.Tensor.repeat: torch_tensor_repeat_override,
+    torch.roll: torch_roll,
     # TODO: those might not be needed.
     # torch.index_select: torch_index_select,
     # torch.Tensor.index_select: torch_tensor_index_select,
+    torch.nn.Conv2d: torch_nn_conv2d,
     torch.nn.MSELoss: torch_nn_mseloss,
     torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
     torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
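
`torch_nn_conv2d` gives the tracer a meta implementation of `Conv2d` shape inference using the standard formula `out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1)`, while `torch_roll` only needs to preserve the shape, since `torch.roll` never changes it. A quick check of the formula against a real convolution (illustrative parameters, not part of the diff):

```python
import math
import torch
from torch import nn

# Parameters chosen to exercise stride, padding, and dilation at once.
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1, dilation=2)
x = torch.randn(1, 3, 224, 224)

h_out = math.floor((224 + 2 * 1 - 2 * (3 - 1) - 1) / 2 + 1)  # 111
print(conv(x).shape)  # torch.Size([1, 8, 111, 111]) -- matches h_out
```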
@@ -368,6 +399,9 @@ class HFProxy(Proxy):
             # we peephole optimize to the method invocation
             return HFAttribute(self, k)

+    def __setitem__(self, indices, values):
+        return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
+
     def __contains__(self, key):
         # To handle cases such as :
         # `"some_key" in kwargs`
@@ -521,6 +555,15 @@ class HFTracer(Tracer):
                 inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
             else:
                 raise NotImplementedError(f"{model_class} not supported yet.")
+        elif "pixel_values" in input_name:
+            batch_size = shape[0]
+            image_size = model.config.image_size
+            if not isinstance(image_size, collections.abc.Iterable):
+                image_size = (image_size, image_size)
+            height, width = image_size
+            inputs_dict[input_name] = torch.zeros(
+                batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
+            )
         elif "mask" in input_name or "ids" in input_name:
             inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
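
The new `collections` import at the top of the file exists for this branch: `image_size` may be stored in the config either as a single int or as a `(height, width)` pair, and the `collections.abc.Iterable` check normalizes both into a tuple before the dummy zero-image batch is allocated.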
@@ -663,6 +706,11 @@ class HFTracer(Tracer):
             else:
                 self.graph.erase_node(node)

+            # TODO: solves GraphModule creation.
+            # Without this, return type annotation "Tuple" is causing code execution failure.
+            if node.op == "output":
+                node.type = None
+
         return self.graph

     def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
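
Background on the TODO: `torch.fx.GraphModule` generates Python source for `forward`, including the return type annotation taken from the output node; if that annotation references a name such as `Tuple` that is not available in the generated code's namespace, executing the generated source fails. Clearing `node.type` simply drops the annotation.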
@@ -761,12 +809,4 @@ def symbolic_trace(
     traced_graph = tracer.trace(model, concrete_args=concrete_args)
     traced = torch.fx.GraphModule(model, traced_graph)

-    # Copy all the original attributes to the traced GraphModule.
-    regular_module_attributes = dir(nn.Module())
-    for name in dir(model):
-        attr = getattr(model, name)
-        if name.startswith("_") or name in regular_module_attributes:
-            continue
-        setattr(traced, name, deepcopy(attr))
-
     return traced
@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True

     test_pruning = False
     test_resize_embeddings = False
...
@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True

     test_pruning = False
     test_resize_embeddings = False
...
@@ -738,8 +738,7 @@ class ModelTesterMixin:
                 traced_model = symbolic_trace(model, input_names)
                 traced_output = traced_model(**filtered_inputs)
             else:
-                input_names = ["input_ids", "attention_mask", "token_type_ids"]
-                input_ids = inputs["input_ids"]
+                input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]

                 labels = inputs.get("labels", None)
                 start_positions = inputs.get("start_positions", None)
@@ -756,12 +755,6 @@ class ModelTesterMixin:

                 model_output = model(**filtered_inputs)

-                rank = len(input_ids.shape)
-                if rank not in [2, 3]:
-                    raise NotImplementedError(
-                        f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
-                    )
-
                 traced_model = symbolic_trace(model, input_names)
                 traced_output = traced_model(**filtered_inputs)
...