Unverified Commit 2c2a2169 authored by Michael Benayoun, committed by GitHub

Fx with meta (#16836)

* Add meta proxy

* Uses meta data to trace data dependent control-flow

* Remove commented class

* Handles torch tensor-creating functions

* Added type annotation to fix tracing

* Tracing works for everything but T5 and GPT-J

* Almost all previously supported models pass

* All architectures can be traced except T5

* Intermediate commit to have a trace of the comparison operators for HFProxy

* Everything works, except loss computation

* Everything works

* Removed unused import

* Overridden methods do not use underlying ops (linear and torch.matmul), and model attributes are copied to the traced version

* Fix torch_matmul_override

* Change attributes reference to deepcopy

* Remove breakpoint and add torch_index_override

* Small fix

* Fix typo

* Replace asserts by explicit exceptions
parent ff846e9b
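The commit extends `transformers.utils.fx` so that symbolic tracing propagates meta tensors alongside the proxies. A minimal sketch of how the feature is exercised (the model choice, config sizes and input names below are illustrative, not part of the commit):

from transformers import BertConfig, BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

# A small randomly initialized model is enough for tracing; no pretrained weights are needed.
config = BertConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=2, intermediate_size=128)
model = BertForSequenceClassification(config)

# symbolic_trace generates dummy inputs, runs HFTracer with meta-tensor propagation,
# and returns a torch.fx.GraphModule that can be inspected or transformed.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
print(traced.graph)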
@@ -850,7 +850,7 @@ class ElectraModel(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -985,7 +985,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutput]:
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1075,7 +1075,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, ElectraForPreTrainingOutput]:
+    ) -> Union[Tuple[torch.Tensor], ElectraForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
@@ -1197,7 +1197,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, MaskedLMOutput]:
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1283,7 +1283,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, TokenClassifierOutput]:
+    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1368,7 +1368,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1469,7 +1469,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, MultipleChoiceModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1564,7 +1564,7 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
......
@@ -508,7 +508,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -727,7 +727,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -842,7 +842,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
@@ -919,7 +919,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1080,7 +1080,7 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, MaskedLMOutput]:
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1193,7 +1193,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutput]:
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1290,7 +1290,7 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, MultipleChoiceModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1390,7 +1390,7 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, TokenClassifierOutput]:
+    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1496,7 +1496,7 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
@@ -13,18 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import builtins
import functools
import inspect
import math
import random
-from types import ModuleType
+import warnings
+from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

import torch
from packaging import version
from torch import nn
-from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
-from torch.fx.node import Argument
+from torch.fx import Graph, GraphModule, Proxy, Tracer

from .. import (
    CONFIG_MAPPING,
@@ -121,107 +122,313 @@ _SUPPORTED_MODELS = tuple(
)
def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
def torch_nn_layernorm_override(self, input):
return input
def torch_nn_linear_override(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
def torch_relu_override(x):
return x
def torch_nn_relu_override(self, x):
return x
def torch_nn_functional_relu_override(x, inplace=False):
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x
def torch_where_override(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
def torch_abs_override(input, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place abs for MetaTensor analysis")
    return input
def torch_arange_override(*args, **kwargs):
n = len(args)
step = 1
if n == 1:
start = 0
end = args[0]
elif n == 2:
start, end = args
else:
start, end, step = args
step = kwargs.get("step", step)
dtype = kwargs.get("dtype")
return torch.empty((end - start) // step, dtype=dtype, device="meta")
def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
dim = axis
if dim < 0:
dim = tensors[0].dim() + dim
shapes = [t.shape for t in tensors]
shape = list(shapes[0])
concatenated_dim = sum(shape[dim] for shape in shapes)
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
return torch.empty(final_shape, device="meta")
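# Illustrative check (not part of the original patch): the override reproduces the shape a real
# concatenation would produce, e.g.
#     torch_cat_override([torch.empty(2, 3, device="meta"), torch.empty(4, 3, device="meta")], dim=0).shape
# evaluates to torch.Size([6, 3]), matching torch.cat on the same meta tensors.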
def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
dim = axis
if dim < 0:
dim = tensors[0].dim() + 1 + dim
shape = list(tensors[0].shape)
shape.insert(dim, len(tensors))
return torch.empty(shape, device="meta")
def torch_add_override(input, other, *, alpha=1, out=None):
if not isinstance(input, torch.Tensor):
return torch.empty_like(other, device="meta")
if not isinstance(other, torch.Tensor):
return torch.empty_like(input, device="meta")
max_length = max(input.dim(), other.dim())
input_shape = list(input.shape) + [1] * (max_length - input.dim())
other_shape = list(other.shape) + [1] * (max_length - other.dim())
shape = []
for i in range(max_length):
shape.append(max(input_shape[i], other_shape[i]))
return torch.empty(shape, device="meta")
def torch_mul_override(input, other, *, out=None):
return torch_add_override(input, other, out=out)
def torch_tensor_mul_override(self, other):
return torch_mul_override(self, other)
def torch_matmul_override(input, other, *, out=None):
d1 = input.dim()
d2 = other.dim()
shape = None
if d1 == 1 and d2 == 1:
shape = None
elif d1 == 2 and d2 == 2:
shape = (input.size(0), other.size(1))
elif d1 == 1 and d2 == 2:
shape = (other.size(1),)
elif d1 == 2 and d2 == 1:
shape = (input.size(0),)
else:
max_length = max(input.dim(), other.dim())
shape1 = list(input.shape)
shape2 = list(other.shape)
if d1 == 1:
shape1 = [1] + shape1
if d2 == 1:
shape2.append(1)
shape1 = [-1] * (max_length - d1) + list(input.shape)
shape2 = [-1] * (max_length - d2) + list(other.shape)
shape = []
for i in range(max_length):
shape.append(max(shape1[i], shape2[i]))
shape[-2] = shape1[-2]
shape[-1] = shape2[-1]
if d1 == 1:
shape.pop(-2)
if d2 == 1:
shape.pop(-1)
if shape is None:
return torch.tensor(0.0, device="meta")
return torch.empty(*shape, device="meta")
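# Illustrative check (not part of the original patch): for a batched matmul the override mirrors the
# shape PyTorch itself would produce, e.g.
#     a = torch.empty(8, 128, 64, device="meta")
#     b = torch.empty(8, 64, 32, device="meta")
#     torch_matmul_override(a, b).shape  # torch.Size([8, 128, 32]), same as torch.matmul(a, b).shape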
def torch_tensor_repeat_override(self, *sizes):
shape = list(self.shape)
for i, x in enumerate(sizes):
shape[i] *= x
return torch.empty(shape, device="meta")
def torch_index_select(input, dim, index, *, out=None):
shape = list(input.shape)
shape[dim] = len(index)
return torch.empty(*shape, device="meta")
def torch_tensor_index_select(self, dim, index):
return torch_index_select(self, dim, index)
def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
else:
shape = (1,)
return torch.empty(shape, device="meta")
def torch_nn_crossentropyloss(self, input, target):
if self.reduction == "none":
shape = target.shape
else:
shape = (1,)
return torch.empty(shape, device="meta")
def torch_nn_bcewithlogitsloss(self, input, target):
if self.reduction == "none":
shape = target.shape
else:
shape = (1,)
return torch.empty(shape, device="meta")
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: torch_nn_layernorm_override,
torch.nn.Linear: torch_nn_linear_override,
torch.relu: torch_relu_override,
torch.nn.functional.relu: torch_nn_functional_relu_override,
torch.nn.ReLU: torch_nn_relu_override,
torch.where: torch_where_override,
torch.abs: torch_abs_override,
torch.arange: torch_arange_override,
torch.cat: torch_cat_override,
torch.stack: torch_stack_override,
torch.add: torch_add_override,
torch.mul: torch_mul_override,
torch.Tensor.mul: torch_tensor_mul_override,
torch.matmul: torch_matmul_override,
torch.Tensor.repeat: torch_tensor_repeat_override,
# TODO: those might not be needed.
# torch.index_select: torch_index_select,
# torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
}
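# Illustrative note (not part of the original patch): during tracing, HFTracer.create_proxy looks up
# the function, tensor method or module type being executed in _MANUAL_META_OVERRIDES and, when an
# override exists, runs it on meta tensors instead of the real op. Meta tensors carry shape and dtype
# without allocating storage, e.g.
#     x = torch.empty(2, 3, device="meta")
#     x.shape    # torch.Size([2, 3])
#     x + x      # still a meta tensor of shape (2, 3); no real computation is performed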
class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.
    """

    def install_metadata(self, metadata):
        self._metadata = metadata

    @property
    def shape(self):
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def dtype(self):
        return self.tracer.root.dtype
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata.dtype
        return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata
        return super().__bool__()

    def __getattr__(self, k):
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __contains__(self, key):
        # To handle cases such as :
        # `"some_key" in kwargs`
        if self.node.op == "placeholder":
            return False
        return super().__contains__(key)
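# Illustrative note (not part of the original patch): because every HFProxy carries a meta tensor in
# `_metadata`, Python-level checks that previously required recorded values can be resolved while
# tracing. For instance, a guard in a model's forward such as
#     if attention_mask is not None and attention_mask.dim() == 2:
# yields a proxy whose metadata is a concrete value, so __bool__ above can answer the test eagerly,
# which is what makes this kind of data-dependent control-flow traceable.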
class HFAttribute(HFProxy):
    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)


class MetaDeviceAttribute(HFAttribute):
    pass
def _proxies_to_metas(v):
    """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if isinstance(v, torch.fx.Proxy):
        if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
            raise RuntimeError(f"No metadata was found for {v}")
        return v._metadata
    return v
def _gen_constructor_wrapper(target):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target
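# Illustrative note (not part of the original patch): _gen_constructor_wrapper turns a tensor
# constructor such as torch.arange into a graph-aware wrapper. While tracing:
#     wrapper, orig = _gen_constructor_wrapper(torch.arange)
#     wrapper(10)          # no Proxy argument -> falls through to the real torch.arange
#     wrapper(seq_length)  # seq_length is a Proxy -> a "call_function" node for torch.arange is emitted
# HFTracer.trace installs these wrappers on the torch module for the names in _TORCH_METHODS_TO_PATCH
# and restores the originals once tracing finishes.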
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -239,27 +446,11 @@ class HFTracer(Tracer):
    regular PyTorch torch.fx.Proxy.
    """

    allow_insert_stateless_mods: bool = True
    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]

    def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
        super().__init__(
            autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
        )
@@ -271,91 +462,6 @@ class HFTracer(Tracer):
            f"{TORCH_FX_REQUIRED_VERSION} is supported."
        )
self.prev_module = None
self.recorded_methods = None
def _register_leaf_function(self, module: ModuleType, name: str):
"""Registers the function called name in module as a leaf function."""
orig_func = getattr(module, name)
patched_func = _function_to_leaf(orig_func)
patched_func.__module__ = __name__
self._leaf_functions_register[name] = (module, orig_func, patched_func)
def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False):
"""Patches leaf functions specifically for root."""
for name in self._leaf_functions_register:
module, orig_func, patched_func = self._leaf_functions_register[name]
if restore:
root.__class__.forward.__globals__.pop(name)
setattr(module, name, orig_func)
else:
root.__class__.forward.__globals__[name] = patched_func
leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__)
leaf_getter.__module__ = __name__
setattr(module, name, leaf_getter)
def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool:
"""
Finds out if the method (that is being recorded) is called inside a leaf module, this allows to not record
outputs that will not be encountered by the tracer.
"""
currentframe = inspect.currentframe()
while currentframe:
if currentframe is None:
return False
module = currentframe.f_locals.get("self", None)
if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"):
return True
currentframe = currentframe.f_back
return False
def _wrap_method_for_model_recording(
self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int]
):
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
method = getattr(torch.Tensor, method_name)
@functools.wraps(method)
def wrapped(*args, **kwargs):
if self._method_is_called_in_leaf_module(module_ids):
return method(*args, **kwargs)
if not hasattr(model, cache_name):
setattr(model, cache_name, [])
cache = getattr(model, cache_name)
res = method(*args, **kwargs)
cache.append(res)
return res
return wrapped
def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]):
"""
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model
inference before symbolic tracing.
"""
cache_names = {}
original_methods = {}
module_ids = set(id(mod) for mod in model.modules())
for method_name in method_names:
cache_name = f"cache_{method_name}"
cache_names[method_name] = cache_name
if not hasattr(torch.Tensor, method_name):
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
continue
original_methods[method_name] = getattr(torch.Tensor, method_name)
setattr(
torch.Tensor,
method_name,
self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids),
)
if method_name == "size":
original_methods["shape"] = torch.Tensor.shape
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
return cache_names, original_methods
    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int]
    ) -> Dict[str, torch.Tensor]:
@@ -365,6 +471,7 @@ class HFTracer(Tracer):
        inputs_dict = {}

        if input_name in ["labels", "start_positions", "end_positions"]:
            batch_size = shape[0]
            if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
@@ -374,8 +481,31 @@
            ]:
                inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
                inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
                if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
                    raise ValueError(
                        "Could not retrieve the problem type for the sequence classification task, please set "
                        'model.config.problem_type to one of the following values: "regression", '
                        '"single_label_classification", or "multi_label_classification".'
                    )
                if model.config.problem_type == "regression":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                elif model.config.problem_type == "single_label_classification":
                    labels_shape = (batch_size,)
                    labels_dtype = torch.long
                elif model.config.problem_type == "multi_label_classification":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                else:
                    raise ValueError(
                        'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                        f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                    )
                inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
            elif model_class in [
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
            ]:
@@ -400,60 +530,82 @@ class HFTracer(Tracer):
        return inputs_dict
    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == "placeholder" and target in self.meta_args:
            rv.install_metadata(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                kwargs["device"] = "meta"

        try:
            args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

            if kind == "call_function":
                meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_method":
                method = getattr(args_metas[0].__class__, target)
                meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_module":
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in _MANUAL_META_OVERRIDES:
                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == "get_attr":
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split(".")
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    if isinstance(attr_itr, torch.Tensor):
                        meta_out = attr_itr.to(device="meta")
                    else:
                        meta_out = attr_itr
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            if not isinstance(rv, Proxy):
                raise ValueError("Don't support composite output yet")
            rv.install_metadata(meta_out)
        except Exception as e:
            warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

        return rv

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:
            return super()._module_getattr(attr, attr_val, parameter_proxy_cache)

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def proxy(self, node):
        return HFProxy(node, self)
for n, p in self.root.named_parameters():
if attr_val is p:
if n not in parameter_proxy_cache:
parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
return parameter_proxy_cache[n]
# TODO: condition this on wether dynamic axes were requested.
if isinstance(attr_val, torch.Tensor):
for n, p in self.root.named_buffers():
if attr_val is p:
if n not in parameter_proxy_cache:
parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
return parameter_proxy_cache[n]
return attr_val
def proxy(self, node: Node):
p = HFProxy(node, self)
if self.recorded_methods:
for method_name, cache_name in self.recorded_methods.items():
return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name]
_create_recorded_proxy_method(p, method_name, cache_name, return_proxy)
return p
    def trace(
        self,
@@ -461,25 +613,42 @@ class HFTracer(Tracer):
        concrete_args: Optional[Dict[str, Any]] = None,
        method_names: Optional[Iterable[str]] = None,
    ) -> Graph:
        if concrete_args is None:
            concrete_args = {}

        sig = inspect.signature(root.forward)
        input_names = sig.parameters.keys() - concrete_args.keys()

        # Creating a random input shape to generate dummy inputs.
        batch_size = _generate_random_int()
        sequence_length = _generate_random_int()
        shape = [batch_size, sequence_length]

        if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
            num_choices = _generate_random_int(low=2, high=5)
            shape.insert(1, num_choices)

        inputs = {}
        for input_name in input_names:
            inputs.update(self._generate_dummy_input(root, input_name, shape))

        concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
        self.meta_args = concrete_metas
        self.patched_torch_methods = {
            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)

        # TODO: keep this until necessary.
        # This is necessary because concrete args are added as input to the traced module since
@@ -496,18 +665,35 @@ class HFTracer(Tracer):
        return self.graph
    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module
        because its attributes are input-dependent.
        """
        return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())
    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as submodule.
        """
        # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
        # It is not possible to insert such modules, those should be traced through.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        already_inserted = False
        while hasattr(self.root, path):
            if getattr(self.root, path) is mod:
                already_inserted = True
                break
            path = f"{mod_name}_{idx}"
            idx += 1

        # No need to add multiple instances of the same module.
        if not already_inserted:
            self.root.add_module(path, mod)
        return path
    def path_of_module(self, mod: nn.Module) -> str:
@@ -519,37 +705,18 @@
        Args:
            mod (str): The `Module` to retrieve the qualified name for.
        """
        try:
            return super().path_of_module(mod)
        except NameError as e:
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                return path
            raise e
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
            m, module_qualified_name
        )
def create_arg(self, a: Any) -> Argument:
if isinstance(a, range):
return super().create_arg(list(a))
return super().create_arg(a)
def symbolic_trace(
@@ -594,4 +761,12 @@ def symbolic_trace(
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    # Copy all the original attributes to the traced GraphModule.
    regular_module_attributes = dir(nn.Module())
    for name in dir(model):
        attr = getattr(model, name)
        if name.startswith("_") or name in regular_module_attributes:
            continue
        setattr(traced, name, deepcopy(attr))

    return traced
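A quick way to sanity-check the traced module produced after this change (a sketch; the model, config sizes and input names are placeholders rather than part of the commit):

from transformers import GPT2Config, GPT2LMHeadModel
from transformers.utils.fx import symbolic_trace

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_embd=64, n_head=2))
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# The original model attributes (e.g. `config`) are copied onto the traced GraphModule,
# so code that reads `model.config` keeps working on the traced version.
assert traced.config.n_layer == 2

# The recorded graph can be printed or rewritten like any torch.fx graph.
traced.graph.print_tabular()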