Unverified Commit 6f7d5db5 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Fix transformers.utils.fx compatibility with torch<2.0 (#28774)

guard sdpa on torch>=2.0
parent 5c8d941d
...@@ -53,6 +53,7 @@ from ..models.auto.modeling_auto import ( ...@@ -53,6 +53,7 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import ( from ..utils import (
ENV_VARS_TRUE_VALUES, ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION, TORCH_FX_REQUIRED_VERSION,
...@@ -608,13 +609,17 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { ...@@ -608,13 +609,17 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.Tensor.unsqueeze: torch_tensor_unsqueeze, torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive, torch.unique_consecutive: torch_unique_consecutive,
torch.nn.functional.one_hot: torch_nn_functional_one_hot, torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.functional.scaled_dot_product_attention: torch_nn_functional_scaled_dot_product_attention,
torch.nn.MSELoss: torch_nn_mseloss, torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
operator.getitem: operator_getitem, operator.getitem: operator_getitem,
} }
if is_torch_greater_or_equal_than_2_0:
_MANUAL_META_OVERRIDES[
torch.nn.functional.scaled_dot_product_attention
] = torch_nn_functional_scaled_dot_product_attention
class HFProxy(Proxy): class HFProxy(Proxy):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment