Unverified Commit 2e7e4280 authored by Michael Benayoun, committed by GitHub

Traced models serialization and torchscripting fix (#17206)

* Fix torch.jit.script and pickling issues

* Fix get_attr issues

* Fix import in function

* Fix GPT-J and T5 tracing for torch=1.11

* Gate graph surgery on torch version

* Modeling minor changes to enable TorchScripting

* Model serialization / deserialization test

* Remove _assert_is_none users
parent 1cd01b0a
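In short, this commit makes the GraphModules produced by transformers' symbolic tracer picklable and scriptable. A minimal sketch of the intended workflow, based on the tests added below; the model class and input names here are illustrative, not taken from the commit:

    import pickle

    import torch
    from transformers import BertConfig, BertModel
    from transformers.utils.fx import symbolic_trace

    # Trace a small randomly initialized model (BERT is just one example of a supported architecture).
    model = BertModel(BertConfig()).eval()
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

    # TorchScript the traced GraphModule; this is what the new common test exercises.
    scripted = torch.jit.script(traced)

    # Round-trip the traced module through pickle; deserialization relies on the
    # class_for_deserialization attribute this commit stores on the traced module.
    restored = pickle.loads(pickle.dumps(traced))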
@@ -187,7 +187,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
...
@@ -211,7 +211,7 @@ class MultiHeadSelfAttention(nn.Module):
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
-        scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
+        scores = scores.masked_fill(mask, torch.tensor(-float("inf")))  # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
...
@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

-        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
...
@@ -147,8 +147,8 @@ class GPTNeoSelfAttention(nn.Module):
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9))

-        self.attn_dropout = nn.Dropout(config.attention_dropout)
-        self.resid_dropout = nn.Dropout(config.resid_dropout)
+        self.attn_dropout = nn.Dropout(float(config.attention_dropout))
+        self.resid_dropout = nn.Dropout(float(config.resid_dropout))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_heads
@@ -188,7 +188,7 @@ class GPTNeoSelfAttention(nn.Module):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
@@ -290,7 +290,7 @@ class GPTNeoMLP(nn.Module):
        self.c_fc = nn.Linear(embed_dim, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, embed_dim)
        self.act = ACT2FN[config.activation_function]
-        self.dropout = nn.Dropout(config.resid_dropout)
+        self.dropout = nn.Dropout(float(config.resid_dropout))

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
@@ -475,7 +475,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-        self.drop = nn.Dropout(config.embed_dropout)
+        self.drop = nn.Dropout(float(config.embed_dropout))
        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -887,7 +887,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

-        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
...
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
-    x = torch.stack((-x2, x1), axis=-1)
+    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
@@ -163,7 +163,7 @@ class GPTJAttention(nn.Module):
        # compute causal mask from causal mask buffer
        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
@@ -971,7 +971,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

-        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
...
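The rotate_every_two change swaps the NumPy-style axis keyword for dim. A likely motivation, sketched below: eager PyTorch accepts axis as an alias, but the TorchScript schema for torch.stack only recognizes dim, so the generated graph code has to use dim to be scriptable. This is a minimal illustration under that assumption, not code from the commit:

    import torch

    @torch.jit.script
    def stack_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # "dim" is the keyword TorchScript's schema for torch.stack understands;
        # the eager-only "axis" alias would fail to compile here.
        return torch.stack([-b, a], dim=-1).flatten(-2)

    print(stack_pair(torch.ones(2, 3), torch.ones(2, 3)).shape)  # torch.Size([2, 6])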
@@ -226,9 +226,9 @@ class MobileBertEmbeddings(nn.Module):
            # dimensional output.
            inputs_embeds = torch.cat(
                [
-                    nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
+                    nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
                    inputs_embeds,
-                    nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
+                    nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
                ],
                dim=2,
            )
...
@@ -18,6 +18,7 @@ import collections
import functools
import inspect
import math
+import operator
import random
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
@@ -26,6 +27,7 @@ import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
+from torch.fx.proxy import ParameterProxy

from .. import (
    CONFIG_MAPPING,
@@ -126,45 +128,45 @@ _SUPPORTED_MODELS = tuple(
)

-def embedding_override(self, input):
+def torch_nn_embedding(self, input):
    return torch.empty(*input.shape, self.weight.shape[-1], device="meta")

-def torch_nn_layernorm_override(self, input):
+def torch_nn_layernorm(self, input):
    return input

-def torch_nn_linear_override(self, input):
+def torch_nn_linear(self, input):
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")

-def torch_relu_override(x):
+def torch_relu(x):
    return x

-def torch_nn_relu_override(self, x):
+def torch_nn_relu(self, x):
    return x

-def torch_nn_functional_relu_override(x, inplace=False):
+def torch_nn_functional_relu(x, inplace=False):
    if not inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x

-def torch_where_override(condition, x, y):
+def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")

-def torch_abs_override(input, *, out=None):
-    if out is None:
+def torch_abs(input, *, out=None):
+    if out is not None:
        raise ValueError("Don't support in-place abs for MetaTensor analysis")
    return input

-def torch_arange_override(*args, **kwargs):
+def torch_arange(*args, **kwargs):
    n = len(args)
    step = 1
    if n == 1:
@@ -179,7 +181,7 @@ def torch_arange_override(*args, **kwargs):
    return torch.empty((end - start) // step, dtype=dtype, device="meta")

-def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
+def torch_cat(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
@@ -193,7 +195,7 @@ def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
    return torch.empty(final_shape, device="meta")

-def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
+def torch_stack(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
@@ -205,7 +207,7 @@ def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
    return torch.empty(shape, device="meta")

-def torch_add_override(input, other, *, alpha=1, out=None):
+def torch_add(input, other, *, alpha=1, out=None):
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
@@ -219,15 +221,15 @@ def torch_add_override(input, other, *, alpha=1, out=None):
    return torch.empty(shape, device="meta")

-def torch_mul_override(input, other, *, out=None):
-    return torch_add_override(input, other, out=out)
+def torch_mul(input, other, *, out=None):
+    return torch_add(input, other, out=out)

-def torch_tensor_mul_override(self, other):
-    return torch_mul_override(self, other)
+def torch_tensor_mul(self, other):
+    return torch_mul(self, other)

-def torch_matmul_override(input, other, *, out=None):
+def torch_matmul(input, other, *, out=None):
    d1 = input.dim()
    d2 = other.dim()
    shape = None
@@ -263,7 +265,13 @@ def torch_matmul_override(input, other, *, out=None):
    return torch.empty(*shape, device="meta")

-def torch_tensor_repeat_override(self, *sizes):
+def torch_einsum(equation, *operands):
+    # TODO: infer shape without performing the computation, this might be quite hard.
+    concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
+    return torch.einsum(equation, *concrete_operands).to("meta")
+
+
+def torch_tensor_repeat(self, *sizes):
    shape = list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
@@ -305,6 +313,18 @@ def torch_nn_conv2d(self, input):
    return torch.empty(shape, device="meta")
+def torch_unsqueeze(input, dim):
+    shape = list(input.shape)
+    if dim < 0:
+        dim = input.dim() + 1 + dim
+    shape.insert(dim, 1)
+    return torch.empty(shape, device="meta")
+
+
+def torch_tensor_unsqueeze(self, dim):
+    return torch_unsqueeze(self, dim)

def torch_nn_mseloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
@@ -329,31 +349,42 @@ def torch_nn_bcewithlogitsloss(self, input, target):
    return torch.empty(shape, device="meta")

+def operator_getitem(a, b):
+    if isinstance(a, torch.Tensor):
+        # TODO: infer shape without performing the computation.
+        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
+    return operator.getitem(a, b)
+
+
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
-    torch.nn.Embedding: embedding_override,
-    torch.nn.LayerNorm: torch_nn_layernorm_override,
-    torch.nn.Linear: torch_nn_linear_override,
-    torch.relu: torch_relu_override,
-    torch.nn.functional.relu: torch_nn_functional_relu_override,
-    torch.nn.ReLU: torch_nn_relu_override,
-    torch.where: torch_where_override,
-    torch.abs: torch_abs_override,
-    torch.arange: torch_arange_override,
-    torch.cat: torch_cat_override,
-    torch.stack: torch_stack_override,
-    torch.add: torch_add_override,
-    torch.mul: torch_mul_override,
-    torch.Tensor.mul: torch_tensor_mul_override,
-    torch.matmul: torch_matmul_override,
-    torch.Tensor.repeat: torch_tensor_repeat_override,
+    torch.nn.Embedding: torch_nn_embedding,
+    torch.nn.LayerNorm: torch_nn_layernorm,
+    torch.nn.Linear: torch_nn_linear,
+    torch.relu: torch_relu,
+    torch.nn.functional.relu: torch_nn_functional_relu,
+    torch.nn.ReLU: torch_nn_relu,
+    torch.where: torch_where,
+    torch.abs: torch_abs,
+    torch.arange: torch_arange,
+    torch.cat: torch_cat,
+    torch.stack: torch_stack,
+    torch.add: torch_add,
+    torch.mul: torch_mul,
+    torch.Tensor.mul: torch_tensor_mul,
+    torch.matmul: torch_matmul,
+    torch.einsum: torch_einsum,
+    torch.Tensor.repeat: torch_tensor_repeat,
    torch.roll: torch_roll,
    # TODO: those might not be needed.
    # torch.index_select: torch_index_select,
    # torch.Tensor.index_select: torch_tensor_index_select,
    torch.nn.Conv2d: torch_nn_conv2d,
+    torch.unsqueeze: torch_unsqueeze,
+    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
+    operator.getitem: operator_getitem,
}
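For context on the _MANUAL_META_OVERRIDES mapping above: most entries are shape-only stand-ins that the tracer dispatches to instead of the real op, returning an empty tensor on the "meta" device with the shape the real kernel would have produced (a few, such as torch_einsum and operator_getitem, instead run the op on empty CPU tensors and move the result to "meta"). A standalone illustration of the shape-only idea, assuming torch >= 1.9 for the meta device; the helper name is invented here, not part of the diff:

    import torch

    def linear_shape_only(module: torch.nn.Linear, inp: torch.Tensor) -> torch.Tensor:
        # Same recipe as torch_nn_linear above: keep every dim but the last,
        # replace the last dim with out_features, and allocate no real storage.
        return torch.empty(inp.shape[:-1] + (module.out_features,), device="meta")

    layer = torch.nn.Linear(16, 32)
    meta_in = torch.empty(4, 8, 16, device="meta")
    print(linear_shape_only(layer, meta_in).shape)  # torch.Size([4, 8, 32])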
@@ -371,7 +402,6 @@ class HFProxy(Proxy):
    @property
    def dtype(self):
-        return self.tracer.root.dtype
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata.dtype
        return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
@@ -400,7 +430,7 @@ class HFProxy(Proxy):
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
-        return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
+        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        # To handle cases such as :
@@ -480,14 +510,14 @@ class HFTracer(Tracer):
    regular PyTorch torch.fx.Proxy.
    """

+    # Feature flag for proxying accesses to buffer values
+    proxy_buffer_attributes: bool = True
    allow_insert_stateless_mods: bool = True
    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]

-    def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
-        super().__init__(
-            autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
-        )
+    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
+        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        if not is_torch_fx_available():
            torch_version = version.parse(importlib_metadata.version("torch"))
@@ -500,7 +530,9 @@ class HFTracer(Tracer):
        self, model: PreTrainedModel, input_name: str, shape: List[int]
    ) -> Dict[str, torch.Tensor]:
        """Generates dummy input for model inference recording."""
-        model_class = model.__class__
+        # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
+        # from pickle, or from the "__class__" attribute in the general case.
+        model_class = getattr(model, "class_for_deserialization", model.__class__)
        device = model.device
        inputs_dict = {}
@@ -641,7 +673,38 @@ class HFTracer(Tracer):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:
-            return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+            # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+                for n, p in collection_to_search:
+                    if attr_val is p:
+                        if n not in parameter_proxy_cache:
+                            kwargs = {}
+                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+                                kwargs["proxy_factory_fn"] = (
+                                    None
+                                    if not self.param_shapes_constant
+                                    else lambda node: ParameterProxy(self, node, n, attr_val)
+                                )
+                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
+                            parameter_proxy_cache[n] = val_proxy
+                        return parameter_proxy_cache[n]
+                return None
+
+            if isinstance(attr_val, torch.nn.Parameter):
+                maybe_parameter_proxy = maybe_get_proxy_for_attr(
+                    attr_val, self.root.named_parameters(), parameter_proxy_cache
+                )
+                if maybe_parameter_proxy is not None:
+                    return maybe_parameter_proxy
+
+            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+                maybe_buffer_proxy = maybe_get_proxy_for_attr(
+                    attr_val, self.root.named_buffers(), parameter_proxy_cache
+                )
+                if maybe_buffer_proxy is not None:
+                    return maybe_buffer_proxy
+
+            return attr_val

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
@@ -693,17 +756,29 @@ class HFTracer(Tracer):
        for name, (_, orig) in self.patched_torch_methods.items():
            setattr(torch, name, orig)

+        # TODO: keep this until necessary.
        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
-        # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
+                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+                    # It cannot infer on the attributes and methods the input should have, and fails.
+                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
+                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+                        # Newer versions of torch.fx emit an assert statement
+                        # for concrete arguments; delete those before we delete
+                        # the concrete arg.
+                        to_delete = []
+                        for user in node.users:
+                            if user.target == torch.fx._symbolic_trace._assert_is_none:
+                                to_delete.append(user)
+                        for user in to_delete:
+                            self.graph.erase_node(user)
+
                    self.graph.erase_node(node)

        # TODO: solves GraphModule creation.
@@ -809,4 +884,10 @@ def symbolic_trace(
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

+    traced.config = model.config
+    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
+    # _generate_dummy_input, where the model class is needed.
+    traced.class_for_deserialization = model.__class__
+    traced.device = model.device
+
    return traced
@@ -325,7 +325,7 @@ torch_version = None
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available:
    torch_version = version.parse(importlib_metadata.version("torch"))
-    _torch_fx_available = (torch_version.major, torch_version.minor) == (
+    _torch_fx_available = (torch_version.major, torch_version.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )
...
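The hunk above relaxes the torch.fx availability gate from an exact (major, minor) match to a minimum version, so torch 1.11 and newer keep fx support. Roughly, with the required version value assumed here purely for illustration:

    from packaging import version

    TORCH_FX_REQUIRED_VERSION = version.parse("1.10")  # assumed minimum, illustration only
    torch_version = version.parse("1.11.0")

    fx_available = (torch_version.major, torch_version.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )
    print(fx_available)  # True with ">=", False with the old "==" check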
@@ -16,11 +16,14 @@
import copy
import inspect
+import os
+import pickle
+import tempfile
import unittest

from transformers import SwinConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from transformers.utils import cached_property, is_torch_available, is_vision_available
+from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -38,6 +41,9 @@ if is_vision_available():
    from transformers import AutoFeatureExtractor

+if is_torch_fx_available():
+    from transformers.utils.fx import symbolic_trace
+

def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
@@ -381,6 +387,97 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )
+
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
+            return
+
+        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+        configs_no_init.return_dict = False
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            model.to(torch_device)
+            model.eval()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+            try:
+                if model.config.is_encoder_decoder:
+                    model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+                    labels = inputs.get("labels", None)
+                    input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+                    if labels is not None:
+                        input_names.append("labels")
+                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+
+                    model_output = model(**filtered_inputs)
+
+                    traced_model = symbolic_trace(model, input_names)
+                    traced_output = traced_model(**filtered_inputs)
+                else:
+                    input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
+                    labels = inputs.get("labels", None)
+                    start_positions = inputs.get("start_positions", None)
+                    end_positions = inputs.get("end_positions", None)
+                    if labels is not None:
+                        input_names.append("labels")
+                    if start_positions is not None:
+                        input_names.append("start_positions")
+                    if end_positions is not None:
+                        input_names.append("end_positions")
+
+                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = filtered_inputs.keys()
+
+                    model_output = model(**filtered_inputs)
+
+                    traced_model = symbolic_trace(model, input_names)
+                    traced_output = traced_model(**filtered_inputs)
+
+            except RuntimeError as e:
+                self.fail(f"Couldn't trace module: {e}")
+
+            def flatten_output(output):
+                flatten = []
+                for x in output:
+                    if isinstance(x, (tuple, list)):
+                        flatten += flatten_output(x)
+                    elif not isinstance(x, torch.Tensor):
+                        continue
+                    else:
+                        flatten.append(x)
+                return flatten
+
+            model_output = flatten_output(model_output)
+            traced_output = flatten_output(traced_output)
+            num_outputs = len(model_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], traced_output[i]),
+                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )

@require_vision
@require_torch
...
@@ -19,6 +19,7 @@ import inspect
import json
import os
import os.path
+import pickle
import random
import sys
import tempfile
@@ -758,8 +759,8 @@ class ModelTesterMixin:
                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)

-            except RuntimeError:
-                self.fail("Couldn't trace module.")
+            except RuntimeError as e:
+                self.fail(f"Couldn't trace module: {e}")

            def flatten_output(output):
                flatten = []
@@ -782,6 +783,40 @@ class ModelTesterMixin:
                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
                )
+            # Test that the model can be TorchScripted
+            try:
+                scripted = torch.jit.script(traced_model)
+            except Exception as e:
+                self.fail(f"Could not TorchScript the traced model: {e}")
+
+            scripted_output = scripted(**filtered_inputs)
+            scripted_output = flatten_output(scripted_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], scripted_output[i]),
+                    f"scripted {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )
+
    def test_headmasking(self):
        if not self.test_head_masking:
            return
...