Unverified Commit 2c2a2169 authored by Michael Benayoun, committed by GitHub

Fx with meta (#16836)

* Add meta proxy

* Uses meta data to trace data dependent control-flow

* Remove commented class

* Handles torch tensor-creating functions

* Added type annotation to fix tracing

* Tracing works for everything but T5 and GPT-J

* Almost all previously supported models pass

* All architectures can be traced except T5

* Intermediate commit to trace the comparison operators for HFProxy

* Everything works, except loss computation

* Everything works

* Removed unused import

* Overridden methods do not use underlying ops (linear and torch.matmul), and model attributes are copied to the traced version

* Fix torch_matmul_override

* Change attribute references to deepcopies

* Remove breakpoint and add torch_index_override

* Small fix

* Fix typo

* Replace asserts with explicit exceptions
parent ff846e9b
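
For context, a minimal sketch of how the meta-based tracing introduced here is meant to be used. This assumes the `symbolic_trace` helper exported from `transformers.utils.fx` (its exact signature is not shown in this diff) and uses an arbitrary example checkpoint:

```python
from transformers import BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

# Any architecture in _SUPPORTED_MODELS can be traced; no real data is needed, because dummy
# inputs are generated internally and propagated as "meta" tensors to resolve data-dependent
# control flow (shapes, sizes, booleans) during tracing.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.config.problem_type = "single_label_classification"  # required so dummy labels can be generated

traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "labels"])
print(traced.graph)  # the resulting torch.fx graph, including the loss computation
```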
......@@ -850,7 +850,7 @@ class ElectraModel(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -985,7 +985,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......@@ -1075,7 +1075,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ElectraForPreTrainingOutput]:
) -> Union[Tuple[torch.Tensor], ElectraForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
......@@ -1197,7 +1197,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
......@@ -1283,7 +1283,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
......@@ -1368,7 +1368,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
......@@ -1469,7 +1469,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MultipleChoiceModelOutput]:
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
......@@ -1564,7 +1564,7 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
......
......@@ -508,7 +508,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -727,7 +727,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
......@@ -842,7 +842,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
......@@ -919,7 +919,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
......@@ -1080,7 +1080,7 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
......@@ -1193,7 +1193,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......@@ -1290,7 +1290,7 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MultipleChoiceModelOutput]:
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
......@@ -1390,7 +1390,7 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
......@@ -1496,7 +1496,7 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
......@@ -13,18 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import functools
import inspect
import math
import random
from types import ModuleType
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx.node import Argument
from torch.fx import Graph, GraphModule, Proxy, Tracer
from .. import (
CONFIG_MAPPING,
......@@ -121,107 +122,313 @@ _SUPPORTED_MODELS = tuple(
)
def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
def torch_nn_layernorm_override(self, input):
return input
def torch_nn_linear_override(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
def torch_relu_override(x):
return x
def torch_nn_relu_override(self, x):
return x
def torch_nn_functional_relu_override(x, inplace=False):
if inplace:
raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
return x
def torch_where_override(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
def torch_abs_override(input, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
return input
def torch_arange_override(*args, **kwargs):
n = len(args)
step = 1
if n == 1:
start = 0
end = args[0]
elif n == 2:
start, end = args
else:
start, end, step = args
step = kwargs.get("step", step)
dtype = kwargs.get("dtype")
return torch.empty((end - start) // step, dtype=dtype, device="meta")
def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
dim = axis
if dim < 0:
dim = tensors[0].dim() + dim
shapes = [t.shape for t in tensors]
shape = list(shapes[0])
concatenated_dim = sum(shape[dim] for shape in shapes)
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
return torch.empty(final_shape, device="meta")
def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
dim = axis
if dim < 0:
dim = tensors[0].dim() + 1 + dim
shape = list(tensors[0].shape)
shape.insert(dim, len(tensors))
return torch.empty(shape, device="meta")
def torch_add_override(input, other, *, alpha=1, out=None):
if not isinstance(input, torch.Tensor):
return torch.empty_like(other, device="meta")
if not isinstance(other, torch.Tensor):
return torch.empty_like(input, device="meta")
max_length = max(input.dim(), other.dim())
input_shape = [1] * (max_length - input.dim()) + list(input.shape)
other_shape = [1] * (max_length - other.dim()) + list(other.shape)
shape = []
for i in range(max_length):
shape.append(max(input_shape[i], other_shape[i]))
return torch.empty(shape, device="meta")
def torch_mul_override(input, other, *, out=None):
return torch_add_override(input, other, out=out)
def torch_tensor_mul_override(self, other):
return torch_mul_override(self, other)
def torch_matmul_override(input, other, *, out=None):
d1 = input.dim()
d2 = other.dim()
shape = None
if d1 == 1 and d2 == 1:
shape = None
elif d1 == 2 and d2 == 2:
shape = (input.size(0), other.size(1))
elif d1 == 1 and d2 == 2:
shape = (other.size(1),)
elif d1 == 2 and d2 == 1:
shape = (input.size(0),)
else:
max_length = max(input.dim(), other.dim())
shape1 = list(input.shape)
shape2 = list(other.shape)
if d1 == 1:
shape1 = [1] + shape1
if d2 == 1:
shape2.append(1)
shape1 = [-1] * (max_length - d1) + list(input.shape)
shape2 = [-1] * (max_length - d2) + list(other.shape)
shape = []
for i in range(max_length):
shape.append(max(shape1[i], shape2[i]))
shape[-2] = shape1[-2]
shape[-1] = shape2[-1]
if d1 == 1:
shape.pop(-2)
if d2 == 1:
shape.pop(-1)
if shape is None:
return torch.tensor(0.0, device="meta")
return torch.empty(*shape, device="meta")
def torch_tensor_repeat_override(self, *sizes):
shape = list(self.shape)
for i, x in enumerate(sizes):
shape[i] *= x
return torch.empty(shape, device="meta")
def torch_index_select(input, dim, index, *, out=None):
shape = list(input.shape)
shape[dim] = len(index)
return torch.empty(*shape, device="meta")
def torch_tensor_index_select(self, dim, index):
return torch_index_select(self, dim, index)
def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
else:
shape = (1,)
return torch.empty(shape, device="meta")
def torch_nn_crossentropyloss(self, input, target):
if self.reduction == "none":
shape = target.shape
else:
shape = (1,)
return torch.empty(shape, device="meta")
def torch_nn_bcewithlogitsloss(self, input, target):
if self.reduction == "none":
shape = target.shape
else:
shape = (1,)
return torch.empty(shape, device="meta")
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: torch_nn_layernorm_override,
torch.nn.Linear: torch_nn_linear_override,
torch.relu: torch_relu_override,
torch.nn.functional.relu: torch_nn_functional_relu_override,
torch.nn.ReLU: torch_nn_relu_override,
torch.where: torch_where_override,
torch.abs: torch_abs_override,
torch.arange: torch_arange_override,
torch.cat: torch_cat_override,
torch.stack: torch_stack_override,
torch.add: torch_add_override,
torch.mul: torch_mul_override,
torch.Tensor.mul: torch_tensor_mul_override,
torch.matmul: torch_matmul_override,
torch.Tensor.repeat: torch_tensor_repeat_override,
# TODO: those might not be needed.
# torch.index_select: torch_index_select,
# torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
}
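# Illustration of the pattern shared by the overrides above: each one only has to return a tensor of
# the correct *shape* on the "meta" device, so shape information can propagate through the traced
# graph without running any real computation. With arbitrary example sizes:
#
# emb = torch.nn.Embedding(100, 32)
# input_ids = torch.zeros(2, 8, dtype=torch.long, device="meta")
# embedding_override(emb, input_ids).shape  # torch.Size([2, 8, 32]), allocated on the meta device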
class HFProxy(Proxy):
"""
Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the
existing ones.
Proxy that uses metadata to handle data-dependent control-flow.
"""
def __init__(self, node: Node, tracer: Optional[Tracer] = None):
super().__init__(node, tracer=tracer)
if hasattr(self, "tracer") and self.tracer is not None:
self.device = self.tracer.root.device
self.dtype = next(self.tracer.root.parameters()).dtype
self.cache = None
def install_metadata(self, metadata):
self._metadata = metadata
@property
def shape(self):
return self.size()
def __setitem__(self, key, value):
pass
def __contains__(self, key):
return False
return self.tracer.create_proxy("call_method", "size", (self,), {})
def __eq__(self, other):
if self.cache is not None:
return self.cache == other
elif isinstance(other, HFProxy):
return True
else:
return super().__eq__(other)
@property
def dtype(self):
return self.tracer.root.dtype
if hasattr(self, "_metadata") and self._metadata is not None:
return self._metadata.dtype
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
def __ne__(self, other):
return not self == other
@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, "device")
def __len__(self):
if self.cache is not None:
if isinstance(self.cache, int):
return self.cache
elif isinstance(self.cache, (torch.Size, list, tuple)):
return len(self.cache)
else:
return super().__len__(self)
return super().__len__(self)
if hasattr(self, "_metadata") and self._metadata is not None:
return len(self._metadata)
return super().__len__()
def __bool__(self):
if hasattr(self, "_metadata") and self._metadata is not None:
return self._metadata
return super().__bool__()
def __getattr__(self, k):
if k == "_metadata":
return self.__getattribute__(k)
# note: not added to the graph yet, if this is a method call
# we peephole optimize to the method invocation
return HFAttribute(self, k)
def __torch_function__(self, orig_method, types, args=None, kwargs=None):
proxy = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs)
proxy.cache = self.cache
return proxy
def __contains__(self, key):
# To handle cases such as :
# `"some_key" in kwargs`
if self.node.op == "placeholder":
return False
return super().__contains__(key)
def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]:
"""Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer."""
class HFAttribute(HFProxy):
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
return wrapper
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]:
@functools.wraps(mapping[func_name])
def wrapper(*args, **kwargs):
return mapping[func_name](*args, **kwargs)
class MetaDeviceAttribute(HFAttribute):
pass
return wrapper
def _proxies_to_metas(v):
"""Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
if isinstance(v, MetaDeviceAttribute):
return "meta"
if isinstance(v, torch.fx.Proxy):
if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
raise RuntimeError(f"No metadata was found for {v}")
return v._metadata
return v
def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool):
"""
Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
during symbolic tracing.
"""
original_method = getattr(torch.Tensor, method_name)
@functools.wraps(original_method)
def method(*args, **kwargs):
cache = getattr(args[0].tracer.root, cache_name)
res = cache.pop(0)
if return_proxy:
proxy = args[0].__torch_function__(
original_method,
None,
args=args,
kwargs=kwargs,
)
proxy.cache = res
return proxy
return res
def _gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
method.__name__ = method_name
bound_method = method.__get__(proxy, proxy.__class__)
setattr(proxy, method_name, bound_method)
def check_has_proxy(v):
if isinstance(v, Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]):
"""Helper function that resets the monkey patched torch.Tensor methods to their original values."""
for name, method in original_methods.items():
setattr(torch.Tensor, name, method)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
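# Example of the behaviour the wrapper above provides (names are illustrative): after patching
# torch.arange with _gen_constructor_wrapper, a call whose arguments contain a Proxy is recorded
# as a "call_function" node in the graph instead of being executed, while plain calls still run eagerly:
#
# wrapped_arange, orig_arange = _gen_constructor_wrapper(torch.arange)
# wrapped_arange(10)               # no Proxy argument -> returns a real torch.Tensor
# wrapped_arange(some_hf_proxy)    # Proxy argument    -> returns a proxy / graph node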
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
......@@ -239,27 +446,11 @@ class HFTracer(Tracer):
regular PyTorch torch.fx.Proxy.
"""
_DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False}
from transformers import modeling_utils
_FUNCTIONS_TO_AUTOWRAP = {
torch: {"arange", "zeros", "ones", "full_like", "eye"},
modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"},
}
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
# Loading the leaf functions register
self._leaf_functions_register = {}
for module, names in self._FUNCTIONS_TO_AUTOWRAP.items():
for name in names:
self._register_leaf_function(module, name)
# TODO: adapt the way leaf functions are wrapped with the "autowrap function" feature from Tracer.
# autowrap_functions = autowrap_functions + tuple(
# patched for (_, _, patched) in self._leaf_functions_register.values()
# )
super().__init__(
autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
)
......@@ -271,91 +462,6 @@ class HFTracer(Tracer):
f"{TORCH_FX_REQUIRED_VERSION} is supported."
)
self.prev_module = None
self.recorded_methods = None
def _register_leaf_function(self, module: ModuleType, name: str):
"""Registers the function called name in module as a leaf function."""
orig_func = getattr(module, name)
patched_func = _function_to_leaf(orig_func)
patched_func.__module__ = __name__
self._leaf_functions_register[name] = (module, orig_func, patched_func)
def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False):
"""Patches leaf functions specifically for root."""
for name in self._leaf_functions_register:
module, orig_func, patched_func = self._leaf_functions_register[name]
if restore:
root.__class__.forward.__globals__.pop(name)
setattr(module, name, orig_func)
else:
root.__class__.forward.__globals__[name] = patched_func
leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__)
leaf_getter.__module__ = __name__
setattr(module, name, leaf_getter)
def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool:
"""
Finds out whether the method being recorded is called inside a leaf module; this avoids recording
outputs that the tracer will never encounter.
"""
currentframe = inspect.currentframe()
while currentframe:
if currentframe is None:
return False
module = currentframe.f_locals.get("self", None)
if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"):
return True
currentframe = currentframe.f_back
return False
def _wrap_method_for_model_recording(
self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int]
):
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
method = getattr(torch.Tensor, method_name)
@functools.wraps(method)
def wrapped(*args, **kwargs):
if self._method_is_called_in_leaf_module(module_ids):
return method(*args, **kwargs)
if not hasattr(model, cache_name):
setattr(model, cache_name, [])
cache = getattr(model, cache_name)
res = method(*args, **kwargs)
cache.append(res)
return res
return wrapped
def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]):
"""
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model
inference before symbolic tracing.
"""
cache_names = {}
original_methods = {}
module_ids = set(id(mod) for mod in model.modules())
for method_name in method_names:
cache_name = f"cache_{method_name}"
cache_names[method_name] = cache_name
if not hasattr(torch.Tensor, method_name):
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
continue
original_methods[method_name] = getattr(torch.Tensor, method_name)
setattr(
torch.Tensor,
method_name,
self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids),
)
if method_name == "size":
original_methods["shape"] = torch.Tensor.shape
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
return cache_names, original_methods
def _generate_dummy_input(
self, model: PreTrainedModel, input_name: str, shape: List[int]
) -> Dict[str, torch.Tensor]:
......@@ -365,6 +471,7 @@ class HFTracer(Tracer):
inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = shape[0]
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
......@@ -374,8 +481,31 @@ class HFTracer(Tracer):
]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
raise ValueError(
"Could not retrieve the problem type for the sequence classification task, please set "
'model.config.problem_type to one of the following values: "regression", '
'"single_label_classification", or "multi_label_classification".'
)
if model.config.problem_type == "regression":
labels_shape = (batch_size, model.config.num_labels)
labels_dtype = torch.float32
elif model.config.problem_type == "single_label_classification":
labels_shape = (batch_size,)
labels_dtype = torch.long
elif model.config.problem_type == "multi_label_classification":
labels_shape = (batch_size, model.config.num_labels)
labels_dtype = torch.float32
else:
raise ValueError(
'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
)
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
elif model_class in [
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]:
......@@ -400,60 +530,82 @@ class HFTracer(Tracer):
return inputs_dict
def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None):
"""
Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing.
"""
if method_names is None:
method_names = self._DEFAULT_METHODS_TO_RECORD
# Creating a random input shape to generate dummy inputs.
batch_size = _generate_random_int()
sequence_length = _generate_random_int()
shape = [batch_size, sequence_length]
if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
inputs = {}
for input_name in input_names:
inputs.update(self._generate_dummy_input(model, input_name, shape))
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
if kind == "placeholder" and target in self.meta_args:
rv.install_metadata(self.meta_args[target])
return rv
if target in self.orig_fns:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
if kind == "call_function":
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
meta_target = _MANUAL_META_OVERRIDES.get(method, method)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in _MANUAL_META_OVERRIDES:
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
else:
return rv
cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names)
self.original_methods = original_methods
if not isinstance(rv, Proxy):
raise ValueError("Don't support composite output yet")
rv.install_metadata(meta_out)
except Exception as e:
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
model(**inputs)
return rv
_reset_tensor_methods(original_methods)
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
self.recorded_methods = {
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name)
}
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if isinstance(attr_val, torch.nn.Parameter):
for n, p in self.root.named_parameters():
if attr_val is p:
if n not in parameter_proxy_cache:
parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
return parameter_proxy_cache[n]
# TODO: condition this on whether dynamic axes were requested.
if isinstance(attr_val, torch.Tensor):
for n, p in self.root.named_buffers():
if attr_val is p:
if n not in parameter_proxy_cache:
parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
return parameter_proxy_cache[n]
return attr_val
def proxy(self, node: Node):
p = HFProxy(node, self)
if self.recorded_methods:
for method_name, cache_name in self.recorded_methods.items():
return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name]
_create_recorded_proxy_method(p, method_name, cache_name, return_proxy)
return p
def proxy(self, node):
return HFProxy(node, self)
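# Illustration (example condition only): because placeholder proxies get a meta tensor installed via
# install_metadata, a data-dependent check in a model's forward such as
#
# if input_ids.shape[1] > 1:
# ...
#
# can be resolved during tracing: HFProxy.__bool__ returns the value computed on the underlying
# meta tensors, where a plain torch.fx Proxy would raise.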
def trace(
self,
......@@ -461,25 +613,42 @@ class HFTracer(Tracer):
concrete_args: Optional[Dict[str, Any]] = None,
method_names: Optional[Iterable[str]] = None,
) -> Graph:
if concrete_args is None:
concrete_args = {}
sig = inspect.signature(root.forward)
input_names = sig.parameters.keys() - concrete_args.keys()
self.record(root, input_names, method_names=method_names)
# Creating a random input shape to generate dummy inputs.
batch_size = _generate_random_int()
sequence_length = _generate_random_int()
shape = [batch_size, sequence_length]
# TODO: adapt the way leaf functions are wrapped with the "autowrap function" feature from Tracer.
autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()]
self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
self._patch_leaf_functions_for_root(root)
inputs = {}
for input_name in input_names:
inputs.update(self._generate_dummy_input(root, input_name, shape))
self.graph = super().trace(root, concrete_args=concrete_args)
concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
self.meta_args = concrete_metas
self.patched_torch_methods = {
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
self._patch_leaf_functions_for_root(root, restore=True)
for name, (wrapper, orig) in self.patched_torch_methods.items():
setattr(torch, name, wrapper)
self.orig_fns.add(orig)
_reset_tensor_methods(self.original_methods)
try:
self.graph = super().trace(root, concrete_args=concrete_args)
finally:
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
# TODO: keep this only as long as it is necessary.
# This is necessary because concrete args are added as input to the traced module since
......@@ -496,18 +665,35 @@ class HFTracer(Tracer):
return self.graph
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
"""
Whether the module was instantiated with Proxies. If that is the case, such a module cannot be a leaf
module because its attributes are input-dependent.
"""
return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
"""
Helper method which tries to insert a module that was not declared as a submodule.
"""
# If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
# It is not possible to insert such modules, those should be traced through.
if self._stateless_mod_instanciation_depends_on_proxies(mod):
return ""
idx = 0
mod_name = mod.__class__.__name__.lower()
path = f"{mod_name}_{idx}"
already_inserted = False
while hasattr(self.root, path):
if getattr(self.root, path) is mod:
already_inserted = True
break
path = f"{mod_name}_{idx}"
idx += 1
self.root.add_module(path, mod)
# No need to add multiple instances of the same module.
if not already_inserted:
self.root.add_module(path, mod)
return path
def path_of_module(self, mod: nn.Module) -> str:
......@@ -519,37 +705,18 @@ class HFTracer(Tracer):
Args:
mod (nn.Module): The `Module` to retrieve the qualified name for.
"""
# Prefer the O(1) algorithm
if hasattr(self, "submodule_paths") and self.submodule_paths:
path = self.submodule_paths.get(mod)
if path is None:
try:
return super().path_of_module(mod)
except NameError as e:
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
path = self._insert_module_as_submodule(mod)
if path is None:
raise NameError(f"Module named {mod._get_name()} is not installed as a submodule")
self.prev_module = path
return path
return path
raise e
# O(N^2) fallback in the case that we didn't store the submodule
# paths.
else:
for n, p in self.root.named_modules():
if mod is p:
self.prev_module = n
return n
path = self._insert_module_as_submodule(mod)
if path is None:
raise NameError(f"Module {mod._get_name()} is not installed as a submodule")
self.prev_module = path
return path
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
is_loss_module = m.__module__.startswith("torch.nn.modules.loss")
return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name)
def create_arg(self, a: Any) -> Argument:
if isinstance(a, range):
return super().create_arg(list(a))
return super().create_arg(a)
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
m, module_qualified_name
)
def symbolic_trace(
......@@ -594,4 +761,12 @@ def symbolic_trace(
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
# Copy all the original attributes to the traced GraphModule.
regular_module_attributes = dir(nn.Module())
for name in dir(model):
attr = getattr(model, name)
if name.startswith("_") or name in regular_module_attributes:
continue
setattr(traced, name, deepcopy(attr))
return traced