Unverified commit fbabd674, authored by Michael Benayoun and committed by GitHub

Fix for Neuron (#30259)

parent 5cf3e6bf
@@ -1010,8 +1010,11 @@ class CohereModel(CoherePreTrainedModel):
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
......
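The mask change above is repeated verbatim for Gemma, Llama and Olmo below: the `eq(0.0)`-and-multiply formulation is replaced by an addition followed by a comparison with zero, which the commit title suggests is friendlier to the AWS Neuron compiler. The two formulations select the same positions, since `causal_mask` holds `0.0` where attention is allowed and `min_dtype` elsewhere, while the 2D `attention_mask` holds `1` for real tokens and `0` for padding. A small standalone sketch (not part of the patch) that checks the equivalence on a toy mask:

import torch

min_dtype = torch.finfo(torch.float32).min
# Toy causal mask for batch=1, head=1, query_len=3, key_len=3:
# 0.0 on and below the diagonal (attend), min_dtype above it (masked).
causal_mask = torch.full((1, 1, 3, 3), min_dtype).triu(diagonal=1)
attention_mask = torch.tensor([[1, 1, 0]])  # the last key position is padding
mask_length = attention_mask.shape[-1]

old_padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
new_padding_mask = (causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]) == 0
assert torch.equal(old_padding_mask, new_padding_mask)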
@@ -1001,8 +1001,11 @@ class GemmaModel(GemmaPreTrainedModel):
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
......
@@ -1089,8 +1089,11 @@ class LlamaModel(LlamaPreTrainedModel):
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
......
@@ -1068,8 +1068,11 @@ class OlmoModel(OlmoPreTrainedModel):
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
......
@@ -84,12 +84,12 @@ if is_torch_neuroncore_available(check_device=False):
if os.environ.get("TORCHELASTIC_RUN_ID"):
if is_optimum_neuron_available():
logger.info(
"Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
"Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
"will fail otherwise."
)
else:
logger.warning(
"Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
"Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
"training on AWS Trainium instances. More information here: "
"https://github.com/huggingface/optimum-neuron"
)
......
@@ -15,22 +15,28 @@
import builtins
import collections
import contextlib
import functools
import inspect
import math
import operator
import os
import random
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
import torch
import torch.utils._pytree as pytree
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.proxy import ParameterProxy
from .. import PretrainedConfig, PreTrainedModel, logging
from .. import logging
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from ..modeling_utils import PretrainedConfig, PreTrainedModel
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -55,7 +61,7 @@ from ..models.auto.modeling_auto import (
MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import (
from .import_utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
@@ -192,6 +198,8 @@ _SPECIAL_SUPPORTED_MODELS = [
]
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
_CURRENT_TRACER = None
def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
@@ -701,6 +709,92 @@ class MetaDeviceAttribute(HFAttribute):
pass
class HFCacheProxy(HFProxy):
"""
Proxy that represents an instance of `transformers.cache_utils.Cache`.
"""
@property
def __class__(self):
return ProxyableCache
def create_wrapper(
function: Callable,
op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
) -> Callable:
@functools.wraps(function)
def wrapper(*args, **kwargs):
if not is_fx_tracing():
return function(*args, **kwargs)
found_proxies = []
def check_proxy(a):
if isinstance(a, Proxy):
found_proxies.append(a)
torch.fx.node.map_aggregate(args, check_proxy)
torch.fx.node.map_aggregate(kwargs, check_proxy)
if len(found_proxies) > 0:
tracer = found_proxies[0].tracer
if op_type == "call_function":
target = function
elif op_type == "call_method":
target = function.__name__
elif op_type == "get_attr":
target = function.__name__
else:
raise ValueError(f"op_type {op_type} not supported.")
return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
else:
return function(*args, **kwargs)
return wrapper
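For illustration, a self-contained sketch (not part of the diff) of how the wrapper behaves: outside of tracing it is a transparent pass-through, and it only routes the call to `tracer.create_proxy` when a `Proxy` shows up among the arguments.

import torch
from transformers.utils.fx import create_wrapper  # added by this diff

wrapped_arange = create_wrapper(torch.arange, "call_function")
print(wrapped_arange(5))  # no tracer active, so this is a plain eager call: tensor([0, 1, 2, 3, 4])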
class HFProxyableClassMeta(type):
"""
Metaclass that creates a class with its main methods wrapped to be proxyable.
"""
def __new__(
cls,
name: str,
bases: Tuple[Type, ...],
attrs: Dict[str, Any],
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
):
cls = super().__new__(cls, name, bases, attrs)
for attr_name in dir(cls):
attr = getattr(cls, attr_name, None)
if attr is None:
continue
if attr_name == "__init__":
op_type = "call_function"
elif attr_name.startswith("__"):
op_type = None
elif inspect.ismethod(attr):
op_type = "call_function"
elif inspect.isfunction(attr):
op_type = "call_method"
else:
op_type = None
if op_type is not None:
setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
return cls
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
"""
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
"""
wrapper = create_wrapper(target, "call_function")
return wrapper, target
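`HFTracer.patch_for_tracing` (further down) consumes this helper to monkey-patch every constructor in `_TORCH_METHODS_TO_PATCH` and to restore the originals afterwards. A minimal sketch of that patch-and-restore pattern, assuming nothing is being traced at the time:

import torch
from transformers.utils.fx import gen_constructor_wrapper

wrapper, original = gen_constructor_wrapper(torch.ones)
torch.ones = wrapper          # while patched, calls made during tracing become graph nodes
try:
    eager = torch.ones(2, 3)  # with no tracer active the wrapper still builds a real tensor
finally:
    torch.ones = original     # always put the genuine constructor back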
def _proxies_to_metas(v):
"""Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
if isinstance(v, MetaDeviceAttribute):
@@ -712,25 +806,24 @@ def _proxies_to_metas(v):
return v
def _gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer):
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
return HFCacheProxy(n, _CURRENT_TRACER)
def check_has_proxy(v):
if isinstance(v, Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
ProxyableDynamicCache = HFProxyableClassMeta(
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
)
ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
)
ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
)
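Outside of tracing, the wrapped methods of these classes fall straight through to the original cache implementations; only under an active `HFTracer` do constructor and method calls turn into `HFCacheProxy` nodes. A quick illustrative check of the first point (not part of the patch):

from transformers.cache_utils import DynamicCache
from transformers.utils.fx import ProxyableDynamicCache

assert issubclass(ProxyableDynamicCache, DynamicCache)
cache = ProxyableDynamicCache()  # behaves like a plain DynamicCache when no tracer is active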
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -764,6 +857,13 @@ class HFTracer(Tracer):
"finfo",
"tril",
]
_CLASSES_TO_PATCH = {
Cache: ProxyableCache,
DynamicCache: ProxyableDynamicCache,
SinkCache: ProxyableSinkCache,
StaticCache: ProxyableStaticCache,
}
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
@@ -776,7 +876,7 @@ class HFTracer(Tracer):
)
def _generate_dummy_input(
self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
@@ -951,6 +1051,11 @@ class HFTracer(Tracer):
args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
should_install_metadata = True
self._disable_module_getattr = True
self._disable_call_module = True
if kind == "call_function":
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
@@ -963,39 +1068,36 @@
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in _MANUAL_META_OVERRIDES:
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in _MANUAL_META_OVERRIDES:
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
else:
return rv
should_install_metadata = False
if should_install_metadata:
if not isinstance(rv, Proxy):
raise ValueError("Don't support composite output yet")
rv.install_metadata(meta_out)
if not isinstance(rv, Proxy):
raise ValueError("Don't support composite output yet")
rv.install_metadata(meta_out)
except Exception as e:
if _IS_IN_DEBUG_MODE:
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
self._disable_module_getattr = False
self._disable_call_module = False
return rv
# Replaced by .getattr from PyTorch 1.13
@@ -1041,12 +1143,51 @@ class HFTracer(Tracer):
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs):
if getattr(self, "_disable_call_module", False):
return forward(*args, **kwargs)
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
def proxy(self, node):
return HFProxy(node, self)
@contextlib.contextmanager
def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
# Patching torch functions
self.patched_torch_methods = {
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
for name, (wrapper, orig) in self.patched_torch_methods.items():
setattr(torch, name, wrapper)
self.orig_fns.add(orig)
# Patching classes
patched = []
module_of_model = inspect.getmodule(root)
for name, mod in sys.modules.items():
if module_of_model is not None and mod is not module_of_model:
continue
if not name.startswith("transformers"):
continue
for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
for attr_name, attr in mod.__dict__.items():
if attr is orig_cls:
patched.append((mod, attr_name, orig_cls))
setattr(mod, attr_name, patched_cls)
yield
# Restoring patched functions and classes.
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
self.patched_torch_methods = {}
self.orig_fns = set()
for mod, attr_name, orig_cls in patched:
setattr(mod, attr_name, orig_cls)
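A minimal sketch of the contract this context manager provides (the real call site is in `trace()` below); the module passed in is arbitrary here, so only the torch constructors get patched:

import torch
from torch import nn
from transformers.utils.fx import HFTracer

tracer = HFTracer()
unpatched_tril = torch.tril
with tracer.patch_for_tracing(nn.Linear(2, 2)):  # any module works for this illustration
    assert torch.tril is not unpatched_tril      # the listed torch constructors are wrapped
assert torch.tril is unpatched_tril              # and restored once the context exits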
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
@@ -1125,28 +1266,25 @@ class HFTracer(Tracer):
" transformers.PreTrainedModel."
)
concrete_metas = {
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
for input_name, input_ in inputs.items()
}
def to_meta(value):
if isinstance(value, torch.Tensor):
return value.to("meta")
return value
concrete_metas = pytree.tree_map(to_meta, inputs)
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
concrete_metas[f"**{param.name}"] = {}
self.meta_args = concrete_metas
self.patched_torch_methods = {
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
for name, (wrapper, orig) in self.patched_torch_methods.items():
setattr(torch, name, wrapper)
self.orig_fns.add(orig)
try:
self.graph = super().trace(root, concrete_args=concrete_args)
finally:
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
global _CURRENT_TRACER
_CURRENT_TRACER = self
with self.patch_for_tracing(root):
try:
self.graph = super().trace(root, concrete_args=concrete_args)
finally:
_CURRENT_TRACER = None
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
@@ -1256,11 +1394,11 @@ def get_concrete_args(model: nn.Module, input_names: List[str]):
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
def is_model_supported(model: PreTrainedModel):
def is_model_supported(model: "PreTrainedModel"):
return model.__class__.__name__ in _SUPPORTED_MODELS
def check_if_model_is_supported(model: PreTrainedModel):
def check_if_model_is_supported(model: "PreTrainedModel"):
if not is_model_supported(model):
supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
@@ -1269,7 +1407,7 @@ def check_if_model_is_supported(model: PreTrainedModel):
def symbolic_trace(
model: PreTrainedModel,
model: "PreTrainedModel",
input_names: Optional[List[str]] = None,
disable_check: bool = False,
tracer_cls: Type[HFTracer] = HFTracer,
@@ -1307,6 +1445,18 @@ def symbolic_trace(
if not disable_check:
check_if_model_is_supported(model)
if "past_key_values" in input_names and not getattr(model.config, "use_cache", False):
logger.warning(
"`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
"unexpected behavior."
)
if "past_key_values" not in input_names and getattr(model.config, "use_cache", False):
logger.warning(
"`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
"model.config.use_cache = False."
)
model.config.use_cache = False
# Tracing.
tracer = tracer_cls()
traced_graph = tracer.trace(model, concrete_args=concrete_args)
......
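For the `use_cache` handling added to `symbolic_trace` above, a hedged usage sketch; the checkpoint name is only a stand-in for any architecture listed in `_SUPPORTED_MODELS`:

from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
# Keeping "past_key_values" in input_names leaves config.use_cache alone; dropping it while
# use_cache is True now logs a warning and sets model.config.use_cache = False before tracing.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])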
@@ -18,7 +18,6 @@ import gc
import inspect
import os
import os.path
import pickle
import random
import re
import tempfile
@@ -1279,26 +1278,6 @@ class ModelTesterMixin:
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
# Avoid memory leaks. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a memory leak of ~0.04MB.)
self.clear_torch_jit_class_registry()
......