Unverified commit fbabd674, authored by Michael Benayoun, committed by GitHub

Fix for Neuron (#30259)

parent 5cf3e6bf
@@ -1010,8 +1010,11 @@ class CohereModel(CoherePreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
...
@@ -1001,8 +1001,11 @@ class GemmaModel(GemmaPreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
...
@@ -1089,8 +1089,11 @@ class LlamaModel(LlamaPreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
...
@@ -1068,8 +1068,11 @@ class OlmoModel(OlmoPreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
...
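All four model files (Cohere, Gemma, Llama, Olmo) receive the same change: instead of combining two `.eq(0.0)` boolean masks with a multiply, the causal mask and the 2D attention mask are summed and the sum compared against zero before `masked_fill`, which is the formulation this PR switches to for Neuron. A minimal equivalence check, not part of the diff, assuming the usual mask conventions (0.0 for visible positions in `causal_mask`, `min_dtype` elsewhere, 0/1 values in `attention_mask`); the `:mask_length` slicing is dropped because the toy mask already has matching length:

# Toy check that the old and new padding-mask formulations agree (illustrative only).
import torch

min_dtype = torch.finfo(torch.float32).min
causal_mask = torch.tensor([[0.0, min_dtype], [0.0, 0.0]]).reshape(1, 1, 2, 2)
attention_mask = torch.tensor([[1, 0]])  # second token is padding

old = causal_mask.eq(0.0) * attention_mask[:, None, None, :].eq(0.0)  # boolean multiply
new = (causal_mask + attention_mask[:, None, None, :]) == 0           # sum, then compare

assert torch.equal(old, new)  # both flag exactly the visible-but-padded positions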
@@ -84,12 +84,12 @@ if is_torch_neuroncore_available(check_device=False):
     if os.environ.get("TORCHELASTIC_RUN_ID"):
         if is_optimum_neuron_available():
             logger.info(
-                "Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
+                "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
                 "will fail otherwise."
             )
         else:
             logger.warning(
-                "Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
+                "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
                 "training on AWS Trainium instances. More information here: "
                 "https://github.com/huggingface/optimum-neuron"
             )
...
@@ -15,22 +15,28 @@
 import builtins
 import collections
+import contextlib
 import functools
 import inspect
 import math
 import operator
 import os
 import random
+import sys
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

 import torch
+import torch.utils._pytree as pytree
 from torch import nn
-from torch.fx import Graph, GraphModule, Proxy, Tracer
+from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
 from torch.fx._compatibility import compatibility
+from torch.fx._symbolic_trace import is_fx_tracing
 from torch.fx.proxy import ParameterProxy

-from .. import PretrainedConfig, PreTrainedModel, logging
+from .. import logging
+from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
+from ..modeling_utils import PretrainedConfig, PreTrainedModel
 from ..models.auto import get_values
 from ..models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -55,7 +61,7 @@ from ..models.auto.modeling_auto import (
     MODEL_MAPPING_NAMES,
 )
 from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
-from ..utils import (
+from .import_utils import (
     ENV_VARS_TRUE_VALUES,
     TORCH_FX_REQUIRED_VERSION,
     get_torch_version,
@@ -192,6 +198,8 @@ _SPECIAL_SUPPORTED_MODELS = [
 ]
 _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))

+_CURRENT_TRACER = None
+

 def torch_nn_embedding(self, input):
     return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
@@ -701,6 +709,92 @@ class MetaDeviceAttribute(HFAttribute):
     pass


+class HFCacheProxy(HFProxy):
+    """
+    Proxy that represents an instance of `transformers.cache_utils.Cache`.
+    """
+
+    @property
+    def __class__(self):
+        return ProxyableCache
+
+
+def create_wrapper(
+    function: Callable,
+    op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
+    proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
+) -> Callable:
+    @functools.wraps(function)
+    def wrapper(*args, **kwargs):
+        if not is_fx_tracing():
+            return function(*args, **kwargs)
+
+        found_proxies = []
+
+        def check_proxy(a):
+            if isinstance(a, Proxy):
+                found_proxies.append(a)
+
+        torch.fx.node.map_aggregate(args, check_proxy)
+        torch.fx.node.map_aggregate(kwargs, check_proxy)
+
+        if len(found_proxies) > 0:
+            tracer = found_proxies[0].tracer
+            if op_type == "call_function":
+                target = function
+            elif op_type == "call_method":
+                target = function.__name__
+            elif op_type == "get_attr":
+                target = function.__name__
+            else:
+                raise ValueError(f"op_type {op_type} not supported.")
+            return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
+        else:
+            return function(*args, **kwargs)
+
+    return wrapper
+
+
+class HFProxyableClassMeta(type):
+    """
+    Metaclass that creates a class with its main methods wrapped to be proxyable.
+    """
+
+    def __new__(
+        cls,
+        name: str,
+        bases: Tuple[Type, ...],
+        attrs: Dict[str, Any],
+        proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
+    ):
+        cls = super().__new__(cls, name, bases, attrs)
+        for attr_name in dir(cls):
+            attr = getattr(cls, attr_name, None)
+            if attr is None:
+                continue
+            if attr_name == "__init__":
+                op_type = "call_function"
+            elif attr_name.startswith("__"):
+                op_type = None
+            elif inspect.ismethod(attr):
+                op_type = "call_function"
+            elif inspect.isfunction(attr):
+                op_type = "call_method"
+            else:
+                op_type = None
+            if op_type is not None:
+                setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
+        return cls
+
+
+def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
+    """
+    Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
+    """
+    wrapper = create_wrapper(target, "call_function")
+    return wrapper, target
+
+
 def _proxies_to_metas(v):
     """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
     if isinstance(v, MetaDeviceAttribute):
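A toy illustration, not from the diff, of what `HFProxyableClassMeta` does: every regular method of the generated class is routed through `create_wrapper`, so instances behave exactly like the base class when no FX tracing is active, while traced calls are recorded as graph nodes instead of executed. `Counter`/`ProxyableCounter` are made-up names, and the import path assumes the helper is exposed from `transformers.utils.fx` on this branch:

from transformers.utils.fx import HFProxyableClassMeta  # assumed import path for this branch


class Counter:
    def __init__(self):
        self.n = 0

    def bump(self, k):
        self.n += k
        return self.n


# Same pattern as the Proxyable* cache classes defined further down in this file.
ProxyableCounter = HFProxyableClassMeta("ProxyableCounter", (Counter,), {})

c = ProxyableCounter()          # outside tracing the wrappers fall through to Counter
assert isinstance(c, Counter)
assert c.bump(3) == 3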
@@ -712,25 +806,24 @@ def _proxies_to_metas(v):
     return v


-def _gen_constructor_wrapper(target):
-    @functools.wraps(target)
-    def wrapper(*args, **kwargs):
-        proxy = None
-
-        def check_has_proxy(v):
-            if isinstance(v, Proxy):
-                nonlocal proxy
-                proxy = v
-
-        torch.fx.node.map_aggregate(args, check_has_proxy)
-        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
-
-        if proxy is not None:
-            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
-        else:
-            return target(*args, **kwargs)
-
-    return wrapper, target
+def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
+    global _CURRENT_TRACER
+    if not isinstance(_CURRENT_TRACER, HFTracer):
+        raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
+    return HFCacheProxy(n, _CURRENT_TRACER)
+
+
+# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
+ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
+ProxyableDynamicCache = HFProxyableClassMeta(
+    "ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
+)
+ProxyableSinkCache = HFProxyableClassMeta(
+    "ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
+)
+ProxyableStaticCache = HFProxyableClassMeta(
+    "ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
+)


 def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
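Outside of tracing, these `Proxyable*` classes are ordinary subclasses of the cache classes they wrap, so `isinstance` checks in modeling code keep passing and runtime behaviour is unchanged; during tracing, `cache_proxy_factory_fn` hands back `HFCacheProxy` nodes bound to the active `HFTracer` via `_CURRENT_TRACER`. A quick eager-mode sketch, assuming this branch of transformers with the class exposed from `transformers.utils.fx`:

import torch

from transformers.cache_utils import DynamicCache
from transformers.utils.fx import ProxyableDynamicCache  # name introduced by this PR (assumed path)

cache = ProxyableDynamicCache()
assert isinstance(cache, DynamicCache)  # isinstance checks in modeling code still pass

key = torch.randn(1, 2, 4, 8)           # (batch, heads, seq_len, head_dim)
value = torch.randn(1, 2, 4, 8)
cache.update(key, value, layer_idx=0)   # falls through to DynamicCache.update when not tracing
assert cache.get_seq_length() == 4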
@@ -764,6 +857,13 @@ class HFTracer(Tracer):
         "finfo",
         "tril",
     ]
+    _CLASSES_TO_PATCH = {
+        Cache: ProxyableCache,
+        DynamicCache: ProxyableDynamicCache,
+        SinkCache: ProxyableSinkCache,
+        StaticCache: ProxyableStaticCache,
+    }
+
     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

     def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
@@ -776,7 +876,7 @@ class HFTracer(Tracer):
         )

     def _generate_dummy_input(
-        self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
+        self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
     ) -> Dict[str, torch.Tensor]:
         """Generates dummy input for model inference recording."""
         # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
@@ -951,6 +1051,11 @@ class HFTracer(Tracer):
             args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
             kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

+            should_install_metadata = True
+
+            self._disable_module_getattr = True
+            self._disable_call_module = True
+
             if kind == "call_function":
                 meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                 meta_out = meta_target(*args_metas, **kwargs_metas)
@@ -963,19 +1068,13 @@ class HFTracer(Tracer):
             elif kind == "call_module":
                 if not hasattr(self, "orig_forward"):
                     raise AttributeError(f"{self} does not have an attribute called orig_forward")
-                self._disable_module_getattr = True
-                try:
-                    mod = self.root.get_submodule(target)
-                    mod_type = type(mod)
-                    if mod_type in _MANUAL_META_OVERRIDES:
-                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
-                    else:
-                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
-                finally:
-                    self._disable_module_getattr = False
+                mod = self.root.get_submodule(target)
+                mod_type = type(mod)
+                if mod_type in _MANUAL_META_OVERRIDES:
+                    meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
+                else:
+                    meta_out = self.orig_forward(*args_metas, **kwargs_metas)
             elif kind == "get_attr":
-                self._disable_module_getattr = True
-                try:
-                    attr_itr = self.root
-                    atoms = target.split(".")
-                    for atom in atoms:
+                attr_itr = self.root
+                atoms = target.split(".")
+                for atom in atoms:
@@ -984,18 +1083,21 @@ class HFTracer(Tracer):
-                        meta_out = attr_itr.to(device="meta")
-                    else:
-                        meta_out = attr_itr
-                finally:
-                    self._disable_module_getattr = False
+                    meta_out = attr_itr.to(device="meta")
+                else:
+                    meta_out = attr_itr
             else:
-                return rv
+                should_install_metadata = False

-            if not isinstance(rv, Proxy):
-                raise ValueError("Don't support composite output yet")
-            rv.install_metadata(meta_out)
+            if should_install_metadata:
+                if not isinstance(rv, Proxy):
+                    raise ValueError("Don't support composite output yet")
+                rv.install_metadata(meta_out)
         except Exception as e:
             if _IS_IN_DEBUG_MODE:
                 warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

+        self._disable_module_getattr = False
+        self._disable_call_module = False
+
         return rv

     # Replaced by .getattr from PyTorch 1.13
@@ -1041,12 +1143,51 @@ class HFTracer(Tracer):
         return self._module_getattr(attr, attr_val, parameter_proxy_cache)

     def call_module(self, m, forward, args, kwargs):
+        if getattr(self, "_disable_call_module", False):
+            return forward(*args, **kwargs)
         self.orig_forward = forward
         return super().call_module(m, forward, args, kwargs)

     def proxy(self, node):
         return HFProxy(node, self)

+    @contextlib.contextmanager
+    def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
+        # Patching torch functions
+        self.patched_torch_methods = {
+            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
+        }
+        self.orig_fns = set()
+
+        for name, (wrapper, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, wrapper)
+            self.orig_fns.add(orig)
+
+        # Patching classes
+        patched = []
+        module_of_model = inspect.getmodule(root)
+        for name, mod in sys.modules.items():
+            if module_of_model is not None and mod is not module_of_model:
+                continue
+            if not name.startswith("transformers"):
+                continue
+            for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
+                for attr_name, attr in mod.__dict__.items():
+                    if attr is orig_cls:
+                        patched.append((mod, attr_name, orig_cls))
+                        setattr(mod, attr_name, patched_cls)
+
+        yield
+
+        # Restoring patched functions and classes.
+        for name, (_, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, orig)
+        self.patched_torch_methods = {}
+        self.orig_fns = set()
+
+        for mod, attr_name, orig_cls in patched:
+            setattr(mod, attr_name, orig_cls)
+
     def trace(
         self,
         root: Union[torch.nn.Module, Callable[..., Any]],
@@ -1125,28 +1266,25 @@ class HFTracer(Tracer):
                 " transformers.PreTrainedModel."
             )

-        concrete_metas = {
-            input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
-            for input_name, input_ in inputs.items()
-        }
+        def to_meta(value):
+            if isinstance(value, torch.Tensor):
+                return value.to("meta")
+            return value
+
+        concrete_metas = pytree.tree_map(to_meta, inputs)
         for param in sig.parameters.values():
             if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                 concrete_metas[f"**{param.name}"] = {}
         self.meta_args = concrete_metas
-        self.patched_torch_methods = {
-            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
-        }
-        self.orig_fns = set()
-
-        for name, (wrapper, orig) in self.patched_torch_methods.items():
-            setattr(torch, name, wrapper)
-            self.orig_fns.add(orig)
-
-        try:
-            self.graph = super().trace(root, concrete_args=concrete_args)
-        finally:
-            for name, (_, orig) in self.patched_torch_methods.items():
-                setattr(torch, name, orig)
+
+        global _CURRENT_TRACER
+        _CURRENT_TRACER = self
+        with self.patch_for_tracing(root):
+            try:
+                self.graph = super().trace(root, concrete_args=concrete_args)
+            finally:
+                _CURRENT_TRACER = None

         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
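The dict comprehension that moved dummy inputs to the meta device is replaced by `pytree.tree_map`, which reaches tensors nested inside lists, tuples and dicts (for example legacy-format `past_key_values`), not just top-level values, and the tracing call itself now runs inside `patch_for_tracing` with `_CURRENT_TRACER` set so that `cache_proxy_factory_fn` can find the active tracer. A small illustration of the tree-mapping step, with made-up inputs:

# Illustrative only: tree_map converts nested tensors that the old comprehension would have missed.
import torch
import torch.utils._pytree as pytree

inputs = {
    "input_ids": torch.ones(2, 8, dtype=torch.long),
    "past_key_values": [(torch.zeros(2, 4, 8, 16), torch.zeros(2, 4, 8, 16))],  # legacy tuple format
}


def to_meta(value):
    if isinstance(value, torch.Tensor):
        return value.to("meta")
    return value


concrete_metas = pytree.tree_map(to_meta, inputs)
assert concrete_metas["input_ids"].is_meta
assert concrete_metas["past_key_values"][0][0].is_meta  # nested tensors are converted too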
@@ -1256,11 +1394,11 @@ def get_concrete_args(model: nn.Module, input_names: List[str]):
     return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}


-def is_model_supported(model: PreTrainedModel):
+def is_model_supported(model: "PreTrainedModel"):
     return model.__class__.__name__ in _SUPPORTED_MODELS


-def check_if_model_is_supported(model: PreTrainedModel):
+def check_if_model_is_supported(model: "PreTrainedModel"):
     if not is_model_supported(model):
         supported_model_names = ", ".join(_SUPPORTED_MODELS)
         raise NotImplementedError(
@@ -1269,7 +1407,7 @@ def check_if_model_is_supported(model: PreTrainedModel):
 def symbolic_trace(
-    model: PreTrainedModel,
+    model: "PreTrainedModel",
     input_names: Optional[List[str]] = None,
     disable_check: bool = False,
     tracer_cls: Type[HFTracer] = HFTracer,
@@ -1307,6 +1445,18 @@ def symbolic_trace(
     if not disable_check:
         check_if_model_is_supported(model)

+    if "past_key_values" in input_names and not getattr(model.config, "use_cache", False):
+        logger.warning(
+            "`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
+            "unexpected behavior."
+        )
+    if "past_key_values" not in input_names and getattr(model.config, "use_cache", False):
+        logger.warning(
+            "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
+            "model.config.use_cache = False."
+        )
+        model.config.use_cache = False
+
     # Tracing.
     tracer = tracer_cls()
     traced_graph = tracer.trace(model, concrete_args=concrete_args)
...
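Taken together, the fx.py changes let `symbolic_trace` handle models whose `past_key_values` are `Cache` objects: `patch_for_tracing` temporarily swaps the cache classes for their proxyable counterparts, and `symbolic_trace` now warns when `past_key_values` and `model.config.use_cache` disagree (setting `use_cache = False` in the latter case). A rough end-to-end sketch, assuming this branch of transformers and a small test checkpoint; not a guaranteed recipe:

from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Tracing with the cache as an explicit input exercises the proxyable cache classes.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])
print(traced.graph)  # torch.fx graph of the traced model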
@@ -18,7 +18,6 @@ import gc
 import inspect
 import os
 import os.path
-import pickle
 import random
 import re
 import tempfile
@@ -1279,26 +1278,6 @@ class ModelTesterMixin:
                         f"traced {i}th output doesn't match model {i}th output for {model_class}",
                     )

-            # Test that the model can be serialized and restored properly
-            with tempfile.TemporaryDirectory() as tmp_dir_name:
-                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
-                try:
-                    with open(pkl_file_name, "wb") as f:
-                        pickle.dump(traced_model, f)
-                    with open(pkl_file_name, "rb") as f:
-                        loaded = pickle.load(f)
-                except Exception as e:
-                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
-
-                loaded_output = loaded(**filtered_inputs)
-                loaded_output = flatten_output(loaded_output)
-
-                for i in range(num_outputs):
-                    self.assertTrue(
-                        torch.allclose(model_output[i], loaded_output[i]),
-                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
-                    )
-
             # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
             # (Even with this call, there are still memory leak by ~0.04MB)
             self.clear_torch_jit_class_registry()
...