Unverified Commit 0fe17f37 authored by Michael Benayoun, committed by GitHub

FX tracing improvement (#14321)

* Change the way tracing happens, enabling dynamic axes out of the box (see the sketch after this message)

* Update the tests and modeling xlnet

* Do not record tensor-method outputs inside leaf modules, to avoid recording more values than will be seen at tracing time (which would otherwise desynchronize the recorded values from the values that need to be given to the proxies during tracing, causing errors)

* Add comments and make tracing work for GPT-J and XLNet

* Refactor things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone use of the autowrap_function feature

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
parent 552f8d30
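Taken together, the changes below simplify the public tracing API: the shape arguments that previously had to be supplied up front are gone, and the tracer now records shapes itself. A minimal before/after sketch (the model name is illustrative; the old keyword arguments are exactly the ones removed by this commit):

```python
from transformers import BertModel
from transformers.utils.fx import symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")

# Before this commit: shapes had to be declared at trace time.
# traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"],
#                         batch_size=1, sequence_length=128)

# After this commit: no shape arguments, dynamic axes out of the box.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
```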
......@@ -1189,6 +1189,16 @@ def create_new_model_like(
if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
]
def disable_fx_test(filename: Path) -> bool:
with open(filename) as fp:
content = fp.read()
new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
with open(filename, "w") as fp:
fp.write(new_content)
return content != new_content
disabled_fx_test = False
for test_file in files_to_adapt:
new_test_file_name = test_file.name.replace(
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
......@@ -1201,6 +1211,13 @@ def create_new_model_like(
dest_file=dest_file,
add_copied_from=False,
)
disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)
if disabled_fx_test:
print(
"The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
"for your new model."
)
# 4. Add model to auto classes
add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
......
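`disable_fx_test` above rewrites a freshly copied test file so the new model starts with its torch.fx tests turned off, and returns whether the file actually changed. A standalone sketch of the same substitution (the file content here is made up for illustration):

```python
import re

content = "class NewModelTest(ModelTesterMixin, unittest.TestCase):\n    fx_compatible = True\n"
new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)

print("fx_compatible = False" in new_content)  # True
print(content != new_content)  # True -> at least one test was disabled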
......@@ -322,7 +322,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
......
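Bumping `TORCH_FX_REQUIRED_VERSION` from 1.9 to 1.10 moves the torch.fx support window forward. A hedged sketch of the kind of gate this constant feeds (not the library's exact `is_torch_fx_available` implementation):

```python
from packaging import version

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
installed_torch = version.parse("1.10.1")  # stand-in for importlib_metadata.version("torch")

# Presumably the gate pins torch.fx support to the required major.minor release.
fx_available = (installed_torch.major, installed_torch.minor) == (
    TORCH_FX_REQUIRED_VERSION.major,
    TORCH_FX_REQUIRED_VERSION.minor,
)
print(fx_available)  # True
```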
......@@ -247,6 +247,27 @@ class ModuleUtilsMixin:
return encoder_extended_attention_mask
def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask
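Factoring the decoder-mask construction out into `create_extended_attention_mask_for_decoder` lets the tracer treat it as a single leaf call (it is registered in `_FUNCTIONS_TO_AUTOWRAP` further down in this diff). The core of the method is a lower-triangular causal mask; a self-contained check:

```python
import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)

# Position j is visible from position i iff j <= i, i.e. a lower-triangular mask.
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask.int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]])
```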
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
......@@ -271,26 +292,9 @@ class ModuleUtilsMixin:
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones(
(batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
extended_attention_mask = self.create_extended_attention_mask_for_decoder(
input_shape, attention_mask, device
)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
......@@ -1861,7 +1865,7 @@ class Conv1D(nn.Module):
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
x = x.view(size_out)
return x
......
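All the `x.view(*shape)` -> `x.view(shape)` changes in this commit (here and in the attention modules below) exist because, at tracing time, `size()` now returns a proxy carrying a recorded value; unpacking it with `*` would force it apart into separate Python integers, while passing the object whole keeps the call traceable. On concrete tensors the two spellings are equivalent, since `view` accepts a shape sequence directly:

```python
import torch

x = torch.randn(2, 3, 4)
size_out = x.size()[:-1] + (2, 2)  # torch.Size([2, 3, 2, 2])

# view accepts either unpacked ints or a single shape sequence.
assert torch.equal(x.view(*size_out), x.view(size_out))
print(x.view(size_out).shape)  # torch.Size([2, 3, 2, 2])
```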
......@@ -293,7 +293,7 @@ class AlbertAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def prune_heads(self, heads):
......
......@@ -252,7 +252,7 @@ class BertSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -341,7 +341,7 @@ class BertSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -245,7 +245,7 @@ class ElectraSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -334,7 +334,7 @@ class ElectraSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
attn_weights = attn_weights / (value.size(-1) ** 0.5)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
......@@ -281,7 +281,7 @@ class GPT2Attention(nn.Module):
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
......@@ -915,7 +915,7 @@ class GPT2Model(GPT2PreTrainedModel):
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(*output_shape)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
loss = None
if labels is not None:
......
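Similarly, `range(batch_size)` only works with a concrete Python integer, whereas `torch.arange` accepts a proxied size and stays inside the traced graph (it is also one of the functions autowrapped as a leaf later in this diff). On concrete tensors both index identically; a quick check with made-up numbers:

```python
import torch

batch_size, seq_len, num_labels = 2, 5, 7
logits = torch.randn(batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([3, 1])  # last non-padding position of each row

# Pick, for every row in the batch, the logits at its last non-padding token.
pooled = logits[torch.arange(batch_size), sequence_lengths]
assert pooled.shape == (batch_size, num_labels)
assert torch.equal(pooled, logits[range(batch_size), sequence_lengths])
```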
......@@ -173,7 +173,7 @@ class GPTNeoSelfAttention(nn.Module):
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
......@@ -637,7 +637,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(*output_shape)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -891,7 +891,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
loss = None
if labels is not None:
......
......@@ -107,7 +107,7 @@ class GPTJAttention(nn.Module):
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(*new_shape)
tensor = tensor.view(new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
......@@ -665,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel):
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(*output_shape)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -945,7 +945,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
loss = None
if labels is not None:
......
......@@ -160,7 +160,7 @@ class LayoutLMSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -249,7 +249,7 @@ class LayoutLMSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -223,7 +223,7 @@ class MegatronBertSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -312,7 +312,7 @@ class MegatronBertSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -237,7 +237,7 @@ class MobileBertSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -274,7 +274,7 @@ class MobileBertSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
......
......@@ -260,7 +260,7 @@ class RealmSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -349,7 +349,7 @@ class RealmSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -187,7 +187,7 @@ class RobertaSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -276,7 +276,7 @@ class RobertaSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -127,7 +127,7 @@ class SplinterSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -216,7 +216,7 @@ class SplinterSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
......@@ -181,7 +181,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
......@@ -270,7 +270,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
......
import copy
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import math
import random
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from types import ModuleType
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
import torch
from packaging import version
......@@ -26,17 +42,11 @@ from .. import (
GPT2DoubleHeadsModel,
PretrainedConfig,
PreTrainedModel,
XLNetForQuestionAnswering,
logging,
)
from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
from ..models.auto import get_values
from .fx_transformations import (
_cache_attributes,
_patch_arguments_,
_restore_attributes_,
transform_to_dynamic_input_,
transformation,
)
logger = logging.get_logger(__name__)
......@@ -46,6 +56,7 @@ def _generate_supported_model_classes(
model_name: Type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[Type[PreTrainedModel]]:
model_config_class = CONFIG_MAPPING[model_name]
task_mapping = {
"default": MODEL_MAPPING,
......@@ -86,15 +97,10 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"gptj",
"gpt_neo",
"t5",
]
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [
"albert",
"bert",
"distilbert",
"mobilebert",
"electra",
"megatron-bert",
"roberta",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "layoutlm",
# "xlnet",
]
_REGULAR_SUPPORTED_MODELS = []
......@@ -106,21 +112,11 @@ for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
_SPECIAL_SUPPORTED_MODELS = [
GPT2DoubleHeadsModel,
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES:
if isinstance(item, dict):
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item))
else:
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item))
_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple(
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES
)
class HFProxy(Proxy):
"""
......@@ -134,6 +130,7 @@ class HFProxy(Proxy):
if hasattr(self, "tracer") and self.tracer is not None:
self.device = self.tracer.root.device
self.dtype = next(self.tracer.root.parameters()).dtype
self.cache = None
@property
def shape(self):
......@@ -145,42 +142,54 @@ class HFProxy(Proxy):
def __contains__(self, key):
return False
def __eq__(self, other):
if self.cache is not None:
return self.cache == other
elif isinstance(other, HFProxy):
return True
else:
return super().__eq__(other)
    def __ne__(self, other):
        return not self == other

    def __len__(self):
        if self.cache is not None:
            if isinstance(self.cache, int):
                return self.cache
            elif isinstance(self.cache, (torch.Size, list, tuple)):
                return len(self.cache)
            else:
                return super().__len__(self)
        return super().__len__(self)

    def __torch_function__(self, orig_method, types, args=None, kwargs=None):
        proxy = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs)
        proxy.cache = self.cache
        return proxy

def _wrap_method_for_model_recording(model, method_name, cache_name):
    """Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
    method = getattr(torch.Tensor, method_name)

    @functools.wraps(method)
    def wrapped(*args, **kwargs):
        if not hasattr(model, cache_name):
            setattr(model, cache_name, [])
        cache = getattr(model, cache_name)
        res = method(*args, **kwargs)
        cache.append(res)
        return res

    return wrapped
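The `cache` attribute threaded through HFProxy's `__eq__`, `__len__` and `__torch_function__` above is what lets Python-level shape logic run during tracing: a proxy standing in for a recorded `size()` result can answer `len(...)` and equality checks concretely. A toy stand-in illustrating the idea (not the actual HFProxy):

```python
import torch

class CachedValue:
    """Toy stand-in for a proxy carrying a value recorded during the dry run."""

    def __init__(self, cache):
        self.cache = cache

    def __len__(self):
        return len(self.cache)

    def __eq__(self, other):
        return self.cache == other

recorded_shape = CachedValue(torch.Size([2, 128]))
print(len(recorded_shape))         # 2 -> `if len(shape) == 2:` branches can run while tracing
print(recorded_shape == (2, 128))  # True
```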
def _create_recorded_proxy_method(proxy, method_name, cache_name):
    """
    Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
    during symbolic tracing.
    """

    def method(self, *args, **kwargs):
        cache = getattr(self.tracer.root, cache_name)
        res = cache.pop(0)
        return res

    method.__name__ = method_name
    bound_method = method.__get__(proxy, proxy.__class__)
    setattr(proxy, method_name, bound_method)

def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper

def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]:
    @functools.wraps(mapping[func_name])
    def wrapper(*args, **kwargs):
        return mapping[func_name](*args, **kwargs)

    return wrapper

def _wrap_method_for_model_tracing(model, method_name, cache_name):
def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool):
    """
    Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values
    Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
    during symbolic tracing.
    """
......@@ -188,55 +197,69 @@ def _wrap_method_for_model_tracing(model, method_name, cache_name):
@functools.wraps(original_method)
def method(*args, **kwargs):
cache = getattr(model, cache_name)
cache = getattr(args[0].tracer.root, cache_name)
res = cache.pop(0)
if return_proxy:
proxy = args[0].__torch_function__(
original_method,
None,
args=args,
kwargs=kwargs,
)
proxy.cache = res
return proxy
return res
setattr(torch.Tensor, method_name, method)
if method_name == "size":
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
"""
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference
before symbolic tracing.
"""
cache_names = dict()
original_methods = dict()
for method_name in method_names:
cache_name = f"cache_{method_name}"
cache_names[method_name] = cache_name
if not hasattr(torch.Tensor, method_name):
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
continue
original_methods[method_name] = getattr(torch.Tensor, method_name)
setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name))
if method_name == "size":
original_methods["shape"] = torch.Tensor.shape
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
return cache_names, original_methods
method.__name__ = method_name
bound_method = method.__get__(proxy, proxy.__class__)
setattr(proxy, method_name, bound_method)
def _reset_tensor_methods(original_methods):
def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]):
"""Helper function that resets the monkey patched torch.Tensor methods to their original values."""
for name, method in original_methods.items():
setattr(torch.Tensor, name, method)
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
if forbidden_values is None:
forbidden_values = []
value = random.randint(low, high)
while value in forbidden_values:
value = random.randint(low, high)
return value
class HFTracer(Tracer):
"""
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
regular PyTorch torch.fx.Proxy.
"""
default_methods_to_record = {"__bool__", "size", "dim"}
_DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False}
from transformers import modeling_utils
_FUNCTIONS_TO_AUTOWRAP = {
torch: {"arange", "zeros", "ones", "full_like", "eye"},
modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"},
}
def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
# Loading the leaf functions register
self._leaf_functions_register = {}
for module, names in self._FUNCTIONS_TO_AUTOWRAP.items():
for name in names:
self._register_leaf_function(module, name)
# TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
# autowrap_functions = autowrap_functions + tuple(
# patched for (_, _, patched) in self._leaf_functions_register.values()
# )
def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
super().__init__()
super().__init__(
autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
)
if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
......@@ -245,40 +268,107 @@ class HFTracer(Tracer):
f"{TORCH_FX_REQUIRED_VERSION} is supported."
)
encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
decoder_sequence_length = (
sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length
)
self.encoder_shape = [batch_size, encoder_sequence_length]
self.decoder_shape = (
[batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
)
self.num_choices = num_choices
if self.num_choices > 0:
self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length]
self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length]
self.prev_module = None
self.recorded_methods = None
def proxy(self, node: Node):
p = HFProxy(node, self)
if self.recorded_methods:
for method_name, cache_name in self.recorded_methods.items():
_create_recorded_proxy_method(p, method_name, cache_name)
return p
def _register_leaf_function(self, module: ModuleType, name: str):
"""Registers the function called name in module as a leaf function."""
orig_func = getattr(module, name)
patched_func = _function_to_leaf(orig_func)
patched_func.__module__ = __name__
self._leaf_functions_register[name] = (module, orig_func, patched_func)
def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False):
"""Patches leaf functions specifically for root."""
for name in self._leaf_functions_register:
module, orig_func, patched_func = self._leaf_functions_register[name]
if restore:
root.__class__.forward.__globals__.pop(name)
setattr(module, name, orig_func)
else:
root.__class__.forward.__globals__[name] = patched_func
leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__)
leaf_getter.__module__ = __name__
setattr(module, name, leaf_getter)
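`_patch_leaf_functions_for_root` relies on the fact that a bare name used inside `forward` is resolved through the function's `__globals__`: injecting a wrapper under the same name redirects every call without touching the model code. A condensed toy version of that half of the trick (the `setattr(module, ...)` half covers `torch.arange`-style attribute lookups; names here are made up):

```python
import functools

def make_leaf(func):
    """Stand-in for _function_to_leaf: calls go through one recordable wrapper."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("leaf wrapper called")  # the real tracer would emit a single call_function node here
        return func(*args, **kwargs)
    return wrapper

def original(n):
    return list(range(n))

def forward(n):
    return original(n)  # bare name: resolved through forward.__globals__

forward.__globals__["original"] = make_leaf(original)
print(forward(3))  # prints "leaf wrapper called", then [0, 1, 2]
```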
def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool:
"""
Finds out whether the method (that is being recorded) is called inside a leaf module; this makes it possible
to avoid recording outputs that will not be encountered by the tracer.
"""
currentframe = inspect.currentframe()
while currentframe:
if currentframe is None:
return False
module = currentframe.f_locals.get("self", None)
if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"):
return True
currentframe = currentframe.f_back
return False
def _wrap_method_for_model_recording(
self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int]
):
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
method = getattr(torch.Tensor, method_name)
@functools.wraps(method)
def wrapped(*args, **kwargs):
if self._method_is_called_in_leaf_module(module_ids):
return method(*args, **kwargs)
if not hasattr(model, cache_name):
setattr(model, cache_name, [])
cache = getattr(model, cache_name)
res = method(*args, **kwargs)
cache.append(res)
return res
return wrapped
def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]):
"""
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model
inference before symbolic tracing.
"""
cache_names = {}
original_methods = {}
module_ids = set(id(mod) for mod in model.modules())
for method_name in method_names:
cache_name = f"cache_{method_name}"
cache_names[method_name] = cache_name
if not hasattr(torch.Tensor, method_name):
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
continue
original_methods[method_name] = getattr(torch.Tensor, method_name)
setattr(
torch.Tensor,
method_name,
self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids),
)
            if method_name == "size":
                original_methods["shape"] = torch.Tensor.shape
                setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))

        return cache_names, original_methods

    def _generate_dummy_input(self, model, input_name):
    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int]
    ) -> Dict[str, torch.Tensor]:
"""Generates dummy input for model inference recording."""
model_class = model.__class__
device = model.device
inputs_dict = dict()
inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = self.encoder_shape[0]
batch_size = shape[0]
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device)
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
XLNetForQuestionAnswering,
]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [
......@@ -288,59 +378,56 @@ class HFTracer(Tracer):
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [
*get_values(MODEL_FOR_PRETRAINING_MAPPING),
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
GPT2DoubleHeadsModel,
]:
inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device)
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device)
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
raise NotImplementedError(f"{model_class} not supported yet.")
elif "mask" in input_name or "ids" in input_name:
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device)
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
shape += [model.config.hidden_size]
inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device)
shape_with_hidden_size = shape + [model.config.hidden_size]
inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
return inputs_dict
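With the `shape` argument now passed in, recording runs the model once on randomly sized dummy tensors. Roughly what `_generate_dummy_input` assembles for a masked-LM-style model, assuming a random shape picked by `_generate_random_int` (the numbers are illustrative):

```python
import torch

shape = [13, 17]  # random [batch_size, sequence_length], as picked in record()

dummy_inputs = {
    "input_ids": torch.zeros(shape, dtype=torch.long),       # "ids" input
    "attention_mask": torch.zeros(shape, dtype=torch.long),  # "mask" input
    "labels": torch.zeros(shape, dtype=torch.long),          # e.g. masked-LM labels
}
print({name: tuple(t.shape) for name, t in dummy_inputs.items()})
# {'input_ids': (13, 17), 'attention_mask': (13, 17), 'labels': (13, 17)}
```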
def record(self, model, input_names, method_names=None):
def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None):
"""
Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic
tracing.
Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing.
"""
if method_names is None:
method_names = self.default_methods_to_record
method_names = self._DEFAULT_METHODS_TO_RECORD
# Creating a random input shape to generate dummy inputs.
batch_size = _generate_random_int()
sequence_length = _generate_random_int()
shape = [batch_size, sequence_length]
if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
inputs = {}
for input_name in input_names:
inputs.update(self._generate_dummy_input(model, input_name))
inputs.update(self._generate_dummy_input(model, input_name, shape))
clone = copy.deepcopy(model)
cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names)
cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names)
self.original_methods = original_methods
clone(**inputs)
# Useful because sometimes the config is changed at inference time, for instance for
# classification tasks where config.problem_type can be set.
model.config = clone.config
model(**inputs)
_reset_tensor_methods(original_methods)
self.recorded_methods = {
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name)
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name)
}
for cache_name in self.recorded_methods.values():
setattr(model, cache_name, getattr(clone, cache_name))
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if isinstance(attr_val, torch.nn.Parameter):
for n, p in self.root.named_parameters():
......@@ -357,7 +444,20 @@ class HFTracer(Tracer):
return parameter_proxy_cache[n]
return attr_val
def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
def proxy(self, node: Node):
p = HFProxy(node, self)
if self.recorded_methods:
for method_name, cache_name in self.recorded_methods.items():
return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name]
_create_recorded_proxy_method(p, method_name, cache_name, return_proxy)
return p
def trace(
self,
root: PreTrainedModel,
concrete_args: Optional[Dict[str, Any]] = None,
method_names: Optional[Iterable[str]] = None,
) -> Graph:
if concrete_args is None:
concrete_args = {}
......@@ -366,11 +466,16 @@ class HFTracer(Tracer):
self.record(root, input_names, method_names=method_names)
for method_name, cache_name in self.recorded_methods.items():
_wrap_method_for_model_tracing(root, method_name, cache_name)
# TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()]
self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
self._patch_leaf_functions_for_root(root)
graph = super().trace(root, concrete_args=concrete_args)
self._patch_leaf_functions_for_root(root, restore=True)
_reset_tensor_methods(self.original_methods)
# TODO: keep this until necessary.
......@@ -388,7 +493,7 @@ class HFTracer(Tracer):
return graph
def _insert_module_as_submodule(self, mod):
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
"""
Helper method which tries to insert a module that was not declared as a submodule.
"""
......@@ -434,72 +539,19 @@ class HFTracer(Tracer):
self.prev_module = path
return path
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
is_loss_module = m.__module__.startswith("torch.nn.modules.loss")
return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name)
def create_arg(self, a: Any) -> Argument:
if isinstance(a, range):
return super().create_arg(list(a))
return super().create_arg(a)
@transformation
def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]:
"""
Prepares a GraphModule produced by symbolic_trace for retracing by:
- Caching all the attributes specific to the way the model was initially traced
- Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes
For instance, the need to retrace a GraphModule can happen when applying quantization.
"""
attributes = _cache_attributes(gm)
_patch_arguments_(gm, gm.dynamic2static)
return gm, attributes
def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]):
"""Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes."""
_restore_attributes_(gm, attributes)
# transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired
# behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values.
transform_to_dynamic_input_(gm, is_retracing=True)
_patch_arguments_(gm, gm.static2dynamic)
return gm
def retrace_graph_with(
gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None
) -> GraphModule:
"""
Retraces a GraphModule by either using a tracer or a function using a tracer (for instance
torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and
restoring anything necessary after the retrace.
"""
if tracer is None and func is None:
raise ValueError("Either a tracer or a function using a tracer must be provided.")
elif tracer is not None and func is not None:
raise ValueError("Either provide a tracer or a function using a tracer, but not both.")
else:
gm, attributes = prepare_for_retracing(gm)
tracing_func = tracer.trace if tracer else func
traced = tracing_func(gm)
restore_after_retracing_(traced, attributes)
return traced
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
if forbidden_values is None:
forbidden_values = []
value = random.randint(low, high)
while value in forbidden_values:
value = random.randint(low, high)
return value
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
batch_size: int = 1,
sequence_length: Union[int, List[int], Tuple[int]] = (128, 128),
num_choices: int = -1,
) -> GraphModule:
"""
......@@ -510,89 +562,33 @@ def symbolic_trace(
The model to trace.
input_names (`List[str]`, *optional*):
The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead.
batch_size (`int`, *optional*, defaults to 1):
The batch size of the traced model inputs.
sequence_length (`int` or `List[int]]`):
The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
lengths between the encoder and the decoder inputs, this must be `[encoder_sequence_length,
decoder_sequence_length]`.
num_choices (`int`, *optional*, defaults to -1):
The number of possible choices for a multiple choice task.
Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
Example:
```python
from transformers.utils.fx import symbolic_trace

traced_model = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    batch_size=1,
    sequence_length=128,
)
traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```
"""
if input_names is None:
input_names = model.dummy_inputs.keys()
sig = inspect.signature(model.forward)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
# Preparing HFTracer batch_size and sequence_length values for potential dynamic axes.
use_dynamic_batch_size = batch_size <= 0
if isinstance(sequence_length, (list, tuple)):
use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0
else:
use_dynamic_sequence_length = sequence_length <= 0
if use_dynamic_batch_size or use_dynamic_sequence_length:
forbidden_values = [
model.config.num_attention_heads,
model.config.hidden_size,
model.config.hidden_size // model.config.num_attention_heads,
]
if use_dynamic_batch_size:
batch_size = _generate_random_int(forbidden_values=forbidden_values)
forbidden_values.append(batch_size)
if use_dynamic_sequence_length:
encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
forbidden_values.append(encoder_sequence_length)
decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
sequence_length = [encoder_sequence_length, decoder_sequence_length]
if not isinstance(model, _SUPPORTED_MODELS):
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance(
model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES
):
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES))
raise NotImplementedError(
f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}"
)
# Tracing.
tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
tracer = HFTracer()
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
traced.config = copy.deepcopy(model.config)
traced.num_choices = num_choices
traced.dummy_inputs = {}
for name in input_names:
traced.dummy_inputs.update(tracer._generate_dummy_input(model, name))
traced.use_dynamic_batch_size = use_dynamic_batch_size
traced.use_dynamic_sequence_length = use_dynamic_sequence_length
traced.static_batch_size = batch_size
traced.static_sequence_length = sequence_length
transform_to_dynamic_input_(traced)
return traced
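End to end, the new flow means a model can be traced once and then run with shapes that differ from the random ones used for recording. A minimal sketch (assuming the `bert-base-uncased` checkpoint is available):

```python
import torch
from transformers import BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# Batch size and sequence length differ from the recording shapes: dynamic axes out of the box.
outputs = traced(
    input_ids=torch.zeros(4, 32, dtype=torch.long),
    attention_mask=torch.ones(4, 32, dtype=torch.long),
)
# `outputs` mirrors the eager model's output structure (e.g. it contains the logits).
```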
import copy
import functools
import operator
from inspect import signature
from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Graph, GraphModule, Node
# Torch FX transformation convention:
# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation
# - transformations that are inplace have a name ending with "_"
def _cache_attributes(gm: GraphModule) -> Dict[str, Any]:
attributes_to_keep = [
"config",
"num_choices",
"dummy_inputs",
"use_dynamic_batch_size",
"use_dynamic_sequence_length",
"static_batch_size",
"static_sequence_length",
"static2dynamic",
"dynamic2static",
]
attributes = {k: getattr(gm, k, None) for k in attributes_to_keep}
return attributes
def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]):
for name, attr in attributes.items():
setattr(gm, name, attr)
def deepcopy_graph(gm: GraphModule) -> GraphModule:
"""
Performs a deepcopy of the GraphModule while also copying the relevant attributes, to know whether the model was
traced with dynamic axes and, if so, what the corresponding values were.
"""
# First, create a copy of the module without the graph.
graph = gm.__dict__.pop("_graph")
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
gm.__dict__["_graph"] = graph
# Then, copy the graph.
val_map = {}
graph_clone = Graph()
output_val = graph_clone.graph_copy(graph, val_map=val_map)
graph_clone.output(output_val)
# Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
# gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
clone = gm.__class__(fake_mod, graph_clone)
# Restore the dynamic axes related attributes to the clone.
attributes = _cache_attributes(gm)
attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
_restore_attributes_(clone, attributes)
return clone
def transformation(func):
"""
Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the
original.
"""
def map_fn(arg):
if isinstance(arg, GraphModule):
return deepcopy_graph(arg)
return arg
@functools.wraps(func)
def wrapper(*args, **kwargs):
new_args = tuple(map_fn(arg) for arg in args)
new_kwargs = {k: map_fn(v) for k, v in kwargs.items()}
return func(*new_args, **new_kwargs)
wrapper._is_transformation = True
return wrapper
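For reference, the removed `@transformation` decorator was meant to be used like the sketch below: the wrapped function receives a deep copy (via `deepcopy_graph`), so the caller's GraphModule is never mutated. The transformation body mirrors `remove_unused_nodes_` from this same file, and assumes `gm` was produced by the `symbolic_trace` above:

```python
from torch.fx import GraphModule

@transformation
def strip_dead_nodes(gm: GraphModule) -> GraphModule:
    # Same logic as remove_unused_nodes_ below, but run on a copy thanks to @transformation.
    for node in gm.graph.nodes:
        if not node.users and node.op not in ["placeholder", "output"]:
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm

# cleaned = strip_dead_nodes(gm)  # `gm` itself is left untouched
```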
def compose_transformations(
*args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False
) -> GraphModule:
"""
Allows transformations to be composed together and takes care of:
1. Performing the transformations on a copy of the GraphModule if inplace is set to False; transformations that
are decorated with @transformation (which means that they do not modify the original GraphModule) are
unwrapped to make them inplace.
2. Linting and recompiling only at the end of the composition, for performance purposes.
"""
args = list(args)
if not inplace:
args.insert(0, deepcopy_graph)
for i, transformation in enumerate(args[:-1]):
sig = signature(transformation)
# Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is
# already handled by this function.
if getattr(transformation, "_is_transformation", False):
transformation = transformation.__wrapped__
# Linting and recompiling only after the last transformation applied to make composition efficient.
if "lint_and_recompile" in sig.parameters:
args[i] = functools.partial(transformation, lint_and_recompile=False)
def reduce_func(f, g):
def compose_f_and_g(gm):
output_g = g(gm)
if output_g is None:
output_g = gm
output_f = f(output_g)
if output_f is None:
output_f = gm
return output_f
return compose_f_and_g
return functools.reduce(reduce_func, reversed(args), lambda x: x)
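A hedged sketch of how `compose_transformations` was meant to be called, chaining inplace helpers from this file (assumes a `gm` traced with the dynamic-axes machinery above):

```python
pipeline = compose_transformations(
    transform_to_dynamic_input_,
    remove_unused_nodes_,
    inplace=True,  # mutate gm directly instead of working on a deep copy
)
# pipeline(gm)  # linting/recompiling is deferred where the helpers support it
```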
def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
"""Removes all the unused nodes in a GraphModule."""
graph = gm.graph
for node in graph.nodes:
if not node.users and node.op not in ["placeholder", "output"]:
graph.erase_node(node)
if lint_and_recompile:
graph.lint()
gm.recompile()
def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
"""Inserts a node that retrieves the batch size dynamically from the input of the model."""
graph = gm.graph
input_names = set(gm.dummy_inputs.keys())
batch_size_node = None
for node in graph.nodes:
if node.op == "placeholder" and node.name in input_names:
with graph.inserting_after(node):
batch_size_node = graph.call_method("size", args=(node, 0))
if batch_size_node is None:
raise ValueError("Could not insert the node that computes the batch size")
if lint_and_recompile:
graph.lint()
gm.recompile()
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[batch_size_node.name] = None
return batch_size_node
def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
"""Inserts a node that retrieves the encoder sequence length dynamically from the input of the model."""
graph = gm.graph
input_names = set(gm.dummy_inputs.keys())
encoder_sequence_length_node = None
for node in graph.nodes:
if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name:
with graph.inserting_after(node):
# There are two cases to handle:
# 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task; in this case the
# input shape is [batch_size, sequence_length] => index 1
# 2. num_choices > 0, meaning that the model is performing a "multiple choice" task; in this case the input
# shape is [batch_size, num_choices, sequence_length] => index 2
encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2))
if encoder_sequence_length_node is None:
raise ValueError("Could not insert the node that computes the encoder sequence length")
if lint_and_recompile:
graph.lint()
gm.recompile()
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[encoder_sequence_length_node.name] = None
return encoder_sequence_length_node
def _change_view_methods_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""
Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the
batch_size / sequence_length nodes.
"""
graph = gm.graph
for node in graph.nodes:
if node.op == "call_method" and node.target == "view":
if isinstance(node.args[1], tuple):
node.args = (node.args[0], *node.args[1])
node.args = tuple((mapping.get(arg, arg) for arg in node.args))
if lint_and_recompile:
graph.lint()
gm.recompile()
def _patch_getitem_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""Patches getitem nodes by replacing current arguments to their corresponding values in mapping."""
# TODO: combine this with the patch_argument function which seems to do almost the same thing.
graph = gm.graph
for node in graph.nodes:
if node.op == "call_function" and node.target == operator.getitem:
indices = node.args[1]
if isinstance(indices, tuple):
new_indices = []
for idx in indices:
if isinstance(idx, slice):
new_indices.append(
slice(
mapping.get(idx.start, idx.start),
mapping.get(idx.stop, idx.stop),
mapping.get(idx.step, idx.step),
)
)
elif isinstance(idx, int):
new_indices.append(mapping.get(idx, idx))
else:
new_indices.append(idx)
node.args = (node.args[0], tuple(new_indices))
else:
node.args = (node.args[0], mapping.get(node.args[1], node.args[1]))
if lint_and_recompile:
graph.lint()
gm.recompile()
def _patch_arguments_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""
Patches nodes by replacing their arguments with the corresponding values in mapping (supports regular types,
tuples and slices).
"""
def _patch_slice(s, mapping):
return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step))
graph = gm.graph
supported_types = (Node, str, int, float)
for node in graph.nodes:
new_args = []
for arg in node.args:
if isinstance(arg, tuple):
new_arg = []
for a in arg:
if isinstance(a, slice):
new_arg.append(_patch_slice(a, mapping))
else:
new_arg.append(mapping.get(a, a))
new_args.append(tuple(new_arg))
elif isinstance(arg, slice):
new_args.append(_patch_slice(arg, mapping))
elif isinstance(arg, supported_types):
new_args.append(mapping.get(arg, arg))
else:
new_args.append(arg)
node.args = tuple(new_args)
if lint_and_recompile:
graph.lint()
gm.recompile()
def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
"""Transformation that enables traced models to perform inference on dynamic input shapes."""
graph = gm.graph
static2dynamic = {}
# Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
if gm.use_dynamic_batch_size:
batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
static2dynamic[gm.static_batch_size] = batch_size_node
if gm.num_choices > 0:
with graph.inserting_after(batch_size_node):
static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function(
operator.mul, args=(batch_size_node, gm.num_choices)
)
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None
if gm.use_dynamic_sequence_length:
encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node
# TODO: do the same for the decoder.
pass
_change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
_patch_getitem_(gm, static2dynamic, lint_and_recompile=False)
remove_unused_nodes_(gm, lint_and_recompile=False)
graph.lint()
gm.recompile()
gm.static2dynamic = static2dynamic
gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
......@@ -231,8 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
fx_compatible = True
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
fx_compatible = True
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
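Finally, the per-class `fx_ready_model_classes` / `fx_dynamic_ready_model_classes` lists collapse into a single `fx_compatible` boolean. Presumably the shared tester gates its torch.fx checks on that flag, along these lines (a sketch, not the actual ModelTesterMixin code):

```python
class ModelTesterMixinSketch:
    fx_compatible = False  # model test classes such as BertModelTest opt in with True

    def test_torch_fx_tracing(self):
        if not self.fx_compatible:
            return  # skip symbolic tracing checks for models not supported yet
        # ... trace each model class with transformers.utils.fx.symbolic_trace
        # and compare traced outputs against eager outputs ...
```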