Unverified Commit 6f7d5db5 authored by fxmarty, committed by GitHub

Fix transformers.utils.fx compatibility with torch<2.0 (#28774)

Guard the SDPA (scaled_dot_product_attention) meta override behind torch>=2.0.
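For context, a minimal self-contained sketch of the guard pattern this commit applies. The version helper and the stub meta override below are illustrative stand-ins, not the exact implementations in `transformers.pytorch_utils` / `transformers.utils.fx`:

```python
# Minimal sketch (illustrative only) of guarding a meta override on torch>=2.0.
import torch
from packaging import version

# Stand-in for is_torch_greater_or_equal_than_2_0 from transformers.pytorch_utils.
is_torch_greater_or_equal_than_2_0 = version.parse(torch.__version__) >= version.parse("2.0")


def torch_nn_functional_scaled_dot_product_attention(query, key, value, *args, **kwargs):
    # Hypothetical stub: output keeps query's batch/seq dims and value's head dim,
    # allocated on the meta device so tracing never runs real kernels.
    target_shape = query.shape[:-1] + value.shape[-1:]
    return torch.empty(target_shape, dtype=query.dtype, device="meta")


_MANUAL_META_OVERRIDES = {}

# torch.nn.functional.scaled_dot_product_attention does not exist on torch<2.0, so an
# unconditional reference raises AttributeError at import time; guarding the
# registration keeps the module importable on older torch.
if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[
        torch.nn.functional.scaled_dot_product_attention
    ] = torch_nn_functional_scaled_dot_product_attention
```

Guarding only the dictionary registration is enough: defining the override function is harmless on older torch, and the import-time failure comes from referencing `torch.nn.functional.scaled_dot_product_attention` as a dict key.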
parent 5c8d941d
@@ -53,6 +53,7 @@ from ..models.auto.modeling_auto import (
     MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
 )
+from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
 from ..utils import (
     ENV_VARS_TRUE_VALUES,
     TORCH_FX_REQUIRED_VERSION,
@@ -608,13 +609,17 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
     torch.unique_consecutive: torch_unique_consecutive,
     torch.nn.functional.one_hot: torch_nn_functional_one_hot,
-    torch.nn.functional.scaled_dot_product_attention: torch_nn_functional_scaled_dot_product_attention,
     torch.nn.MSELoss: torch_nn_mseloss,
     torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
     torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
     operator.getitem: operator_getitem,
 }
 
+if is_torch_greater_or_equal_than_2_0:
+    _MANUAL_META_OVERRIDES[
+        torch.nn.functional.scaled_dot_product_attention
+    ] = torch_nn_functional_scaled_dot_product_attention
+
 
 class HFProxy(Proxy):
     """