Unverified Commit 2e7e4280 authored by Michael Benayoun, committed by GitHub

Traced models serialization and torchscripting fix (#17206)

* Fix torch.jit.script and pickling issues

* Fix get_attr issues

* Fix import in function

* Fix GPT-J and T5 tracing for torch=1.11

* Gate graph surgery on torch version

* Modeling minor changes to enable TorchScripting

* Model serialization / deserialization test

* Remove _assert_is_none users
parent 1cd01b0a
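A minimal sketch of the end-to-end workflow this commit targets, assuming a small encoder checkpoint; the model name and input names below are illustrative, not part of the commit:

import pickle
import torch
from transformers import BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()
# Symbolic tracing with the HF tracer; torch.jit.script and pickle of the result are what this commit fixes.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
scripted = torch.jit.script(traced)
restored = pickle.loads(pickle.dumps(traced))  # relies on the `class_for_deserialization` attribute stored on the traced module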
@@ -187,7 +187,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
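A quick eager-mode check, on an illustrative 4x4 causal mask, that the `.to(torch.bool)` spelling above is value-preserving; the change only swaps the method used to obtain the boolean mask:

import torch

bias = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)  # same layout as the registered "bias" buffer
assert torch.equal(bias.bool(), bias.to(torch.bool))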
@@ -211,7 +211,7 @@ class MultiHeadSelfAttention(nn.Module):
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
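The same kind of check for the `masked_fill` change: wrapping the fill value in a 0-d tensor leaves the result unchanged (shapes here are illustrative):

import torch

scores = torch.randn(2, 4)
mask = torch.tensor([[True, False, False, True], [False, True, False, False]])
assert torch.equal(
    scores.masked_fill(mask, -float("inf")),
    scores.masked_fill(mask, torch.tensor(-float("inf"))),
)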
@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
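The pooling pattern above in isolation, with the index tensor created on `logits.device` so the gather stays on the same device as the logits even when the traced module has no usable `self.device` (sizes are illustrative):

import torch

logits = torch.randn(3, 7, 2)                  # (batch_size, sequence_length, num_labels)
sequence_lengths = torch.tensor([6, 4, 2])     # index of the last non-padding token per example
pooled_logits = logits[torch.arange(3, device=logits.device), sequence_lengths]
print(pooled_logits.shape)                     # torch.Size([3, 2])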
@@ -147,8 +147,8 @@ class GPTNeoSelfAttention(nn.Module):
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.attn_dropout = nn.Dropout(config.attention_dropout)
self.resid_dropout = nn.Dropout(config.resid_dropout)
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
@@ -188,7 +188,7 @@ class GPTNeoSelfAttention(nn.Module):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
@@ -290,7 +290,7 @@ class GPTNeoMLP(nn.Module):
self.c_fc = nn.Linear(embed_dim, intermediate_size)
self.c_proj = nn.Linear(intermediate_size, embed_dim)
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_dropout)
self.dropout = nn.Dropout(float(config.resid_dropout))
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
@@ -475,7 +475,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embed_dropout)
self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -887,7 +887,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
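An illustration of why the `float(...)` casts above help: a config may store a dropout probability as the int `0`, and casting keeps the module attribute a float, which is friendlier to TorchScript's static typing while leaving the value unchanged:

import torch.nn as nn

attention_dropout = 0                          # as it may appear in a config
drop = nn.Dropout(float(attention_dropout))
print(type(drop.p), drop.p)                    # <class 'float'> 0.0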
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), axis=-1)
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
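A small eager-mode check that `dim=-1` reproduces the previous `axis=-1` behaviour in `rotate_every_two`; `dim` is the canonical PyTorch keyword, which tracing and scripting handle more reliably (input shape is illustrative):

import torch

x = torch.arange(16.0).view(1, 1, 2, 8)
x1, x2 = x[:, :, :, ::2], x[:, :, :, 1::2]
assert torch.equal(torch.stack((-x2, x1), dim=-1), torch.stack((-x2, x1), axis=-1))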
@@ -163,7 +163,7 @@ class GPTJAttention(nn.Module):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
@@ -971,7 +971,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
@@ -226,9 +226,9 @@ class MobileBertEmbeddings(nn.Module):
# dimensional output.
inputs_embeds = torch.cat(
[
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
inputs_embeds,
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
],
dim=2,
)
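An equivalence check for the padding-value change above: `value=0.0` pads with the same zeros as `value=0`, the float literal only makes the argument type explicit (embedding shape is illustrative):

import torch
import torch.nn as nn

inputs_embeds = torch.randn(2, 5, 8)
a = nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0)
b = nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0)
assert torch.equal(a, b)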
@@ -18,6 +18,7 @@ import collections
import functools
import inspect
import math
import operator
import random
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
@@ -26,6 +27,7 @@ import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx.proxy import ParameterProxy
from .. import (
CONFIG_MAPPING,
@@ -126,45 +128,45 @@ _SUPPORTED_MODELS = tuple(
)
def embedding_override(self, input):
def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
def torch_nn_layernorm_override(self, input):
def torch_nn_layernorm(self, input):
return input
def torch_nn_linear_override(self, input):
def torch_nn_linear(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
def torch_relu_override(x):
def torch_relu(x):
return x
def torch_nn_relu_override(self, x):
def torch_nn_relu(self, x):
return x
def torch_nn_functional_relu_override(x, inplace=False):
def torch_nn_functional_relu(x, inplace=False):
if not inplace:
raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
return x
def torch_where_override(condition, x, y):
def torch_where(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
def torch_abs_override(input, *, out=None):
if out is None:
def torch_abs(input, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
return input
def torch_arange_override(*args, **kwargs):
def torch_arange(*args, **kwargs):
n = len(args)
step = 1
if n == 1:
@@ -179,7 +181,7 @@ def torch_arange_override(*args, **kwargs):
return torch.empty((end - start) // step, dtype=dtype, device="meta")
def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
def torch_cat(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
@@ -193,7 +195,7 @@ def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
return torch.empty(final_shape, device="meta")
def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
def torch_stack(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
@@ -205,7 +207,7 @@ def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
return torch.empty(shape, device="meta")
def torch_add_override(input, other, *, alpha=1, out=None):
def torch_add(input, other, *, alpha=1, out=None):
if not isinstance(input, torch.Tensor):
return torch.empty_like(other, device="meta")
if not isinstance(other, torch.Tensor):
@@ -219,15 +221,15 @@ def torch_add_override(input, other, *, alpha=1, out=None):
return torch.empty(shape, device="meta")
def torch_mul_override(input, other, *, out=None):
return torch_add_override(input, other, out=out)
def torch_mul(input, other, *, out=None):
return torch_add(input, other, out=out)
def torch_tensor_mul_override(self, other):
return torch_mul_override(self, other)
def torch_tensor_mul(self, other):
return torch_mul(self, other)
def torch_matmul_override(input, other, *, out=None):
def torch_matmul(input, other, *, out=None):
d1 = input.dim()
d2 = other.dim()
shape = None
@@ -263,7 +265,13 @@ def torch_matmul_override(input, other, *, out=None):
return torch.empty(*shape, device="meta")
def torch_tensor_repeat_override(self, *sizes):
def torch_einsum(equation, *operands):
# TODO: infer shape without performing the computation, this might be quite hard.
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
return torch.einsum(equation, *concrete_operands).to("meta")
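How the `torch_einsum` override behaves in isolation: it materialises throwaway CPU tensors of the same shapes, lets `torch.einsum` work out the output shape, and moves the result back to the meta device (toy shapes; assumes the function above is in scope):

import torch

query = torch.empty(2, 3, 4, device="meta")
key = torch.empty(2, 5, 4, device="meta")
out = torch_einsum("bqd,bkd->bqk", query, key)
print(out.shape, out.device)                   # torch.Size([2, 3, 5]) meta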
def torch_tensor_repeat(self, *sizes):
shape = list(self.shape)
for i, x in enumerate(sizes):
shape[i] *= x
@@ -305,6 +313,18 @@ def torch_nn_conv2d(self, input):
return torch.empty(shape, device="meta")
def torch_unsqueeze(input, dim):
shape = list(input.shape)
if dim < 0:
dim = input.dim() + 1 + dim
shape.insert(dim, 1)
return torch.empty(shape, device="meta")
def torch_tensor_unsqueeze(self, dim):
return torch_unsqueeze(self, dim)
def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
@@ -329,31 +349,42 @@ def torch_nn_bcewithlogitsloss(self, input, target):
return torch.empty(shape, device="meta")
def operator_getitem(a, b):
if isinstance(a, torch.Tensor):
# TODO: infer shape without performing the computation.
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
return operator.getitem(a, b)
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: torch_nn_layernorm_override,
torch.nn.Linear: torch_nn_linear_override,
torch.relu: torch_relu_override,
torch.nn.functional.relu: torch_nn_functional_relu_override,
torch.nn.ReLU: torch_nn_relu_override,
torch.where: torch_where_override,
torch.abs: torch_abs_override,
torch.arange: torch_arange_override,
torch.cat: torch_cat_override,
torch.stack: torch_stack_override,
torch.add: torch_add_override,
torch.mul: torch_mul_override,
torch.Tensor.mul: torch_tensor_mul_override,
torch.matmul: torch_matmul_override,
torch.Tensor.repeat: torch_tensor_repeat_override,
torch.nn.Embedding: torch_nn_embedding,
torch.nn.LayerNorm: torch_nn_layernorm,
torch.nn.Linear: torch_nn_linear,
torch.relu: torch_relu,
torch.nn.functional.relu: torch_nn_functional_relu,
torch.nn.ReLU: torch_nn_relu,
torch.where: torch_where,
torch.abs: torch_abs,
torch.arange: torch_arange,
torch.cat: torch_cat,
torch.stack: torch_stack,
torch.add: torch_add,
torch.mul: torch_mul,
torch.Tensor.mul: torch_tensor_mul,
torch.matmul: torch_matmul,
torch.einsum: torch_einsum,
torch.Tensor.repeat: torch_tensor_repeat,
torch.roll: torch_roll,
# TODO: those might not be needed.
# torch.index_select: torch_index_select,
# torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.Conv2d: torch_nn_conv2d,
torch.unsqueeze: torch_unsqueeze,
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
operator.getitem: operator_getitem,
}
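A minimal sketch, not the tracer's real call path, of how `_MANUAL_META_OVERRIDES` is meant to be used: when tracing hits one of these callables on meta tensors, the shape-only stand-in is dispatched instead of the real op (assumes the dict and the overrides above are in scope):

import torch

def call_with_meta_override(fn, *args, **kwargs):
    # Fall back to the real callable when no shape-only override is registered.
    override = _MANUAL_META_OVERRIDES.get(fn, fn)
    return override(*args, **kwargs)

a = torch.empty(2, 3, device="meta")
b = torch.empty(2, 3, device="meta")
out = call_with_meta_override(torch.add, a, b)
print(out.shape, out.device)                   # torch.Size([2, 3]) meta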
@@ -371,7 +402,6 @@ class HFProxy(Proxy):
@property
def dtype(self):
return self.tracer.root.dtype
if hasattr(self, "_metadata") and self._metadata is not None:
return self._metadata.dtype
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
@@ -400,7 +430,7 @@ class HFProxy(Proxy):
return HFAttribute(self, k)
def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
def __contains__(self, key):
# To handle cases such as :
@@ -480,14 +510,14 @@ class HFTracer(Tracer):
regular PyTorch torch.fx.Proxy.
"""
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
super().__init__(
autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
)
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
@@ -500,7 +530,9 @@ class HFTracer(Tracer):
self, model: PreTrainedModel, input_name: str, shape: List[int]
) -> Dict[str, torch.Tensor]:
"""Generates dummy input for model inference recording."""
model_class = model.__class__
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
# from pickle, or from the "__class__" attribute in the general case.
model_class = getattr(model, "class_for_deserialization", model.__class__)
device = model.device
inputs_dict = {}
@@ -641,7 +673,38 @@ class HFTracer(Tracer):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (
None
if not self.param_shapes_constant
else lambda node: ParameterProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_buffers(), parameter_proxy_cache
)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
@@ -693,17 +756,29 @@ class HFTracer(Tracer):
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
# TODO: keep this until necessary.
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
# A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in input_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
# TODO: solves GraphModule creation.
@@ -809,4 +884,10 @@ def symbolic_trace(
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
traced.config = model.config
# The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
# _generate_dummy_input, where the model class is needed.
traced.class_for_deserialization = model.__class__
traced.device = model.device
return traced
@@ -325,7 +325,7 @@ torch_version = None
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
_torch_fx_available = (torch_version.major, torch_version.minor) == (
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
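The relaxed gate above in isolation: any installed torch at or above the required (major, minor) pair now enables FX support, instead of only the exact minor release (the version strings below are illustrative):

from packaging import version

torch_version = version.parse("1.11.0")
required = version.parse("1.10")
print((torch_version.major, torch_version.minor) >= (required.major, required.minor))  # True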
@@ -16,11 +16,14 @@
import copy
import inspect
import os
import pickle
import tempfile
import unittest
from transformers import SwinConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -38,6 +41,9 @@ if is_vision_available():
from transformers import AutoFeatureExtractor
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
@@ -381,6 +387,97 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
configs_no_init = _config_zero_init(config) # To be sure we have no NaN
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
labels = inputs.get("labels", None)
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = filtered_inputs.keys()
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
for x in output:
if isinstance(x, (tuple, list)):
flatten += flatten_output(x)
elif not isinstance(x, torch.Tensor):
continue
else:
flatten.append(x)
return flatten
model_output = flatten_output(model_output)
traced_output = flatten_output(traced_output)
num_outputs = len(model_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], traced_output[i]),
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
@require_vision
@require_torch
@@ -19,6 +19,7 @@ import inspect
import json
import os
import os.path
import pickle
import random
import sys
import tempfile
@@ -758,8 +759,8 @@ class ModelTesterMixin:
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
except RuntimeError as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
@@ -782,6 +783,40 @@ class ModelTesterMixin:
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be TorchScripted
try:
scripted = torch.jit.script(traced_model)
except Exception as e:
self.fail(f"Could not TorchScript the traced model: {e}")
scripted_output = scripted(**filtered_inputs)
scripted_output = flatten_output(scripted_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], scripted_output[i]),
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
def test_headmasking(self):
if not self.test_head_masking:
return