Unverified Commit e58cbed5 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Revert #20715 (#26734)



* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent b219ae6b
...@@ -186,7 +186,9 @@ from .import_utils import ( ...@@ -186,7 +186,9 @@ from .import_utils import (
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
is_vision_available, is_vision_available,
requires_backends, requires_backends,
tf_required,
torch_only_method, torch_only_method,
torch_required,
) )
from .peft_utils import ( from .peft_utils import (
ADAPTER_CONFIG_NAME, ADAPTER_CONFIG_NAME,
......
...@@ -24,7 +24,7 @@ import subprocess ...@@ -24,7 +24,7 @@ import subprocess
import sys import sys
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from functools import lru_cache from functools import lru_cache, wraps
from itertools import chain from itertools import chain
from types import ModuleType from types import ModuleType
from typing import Any, Tuple, Union from typing import Any, Tuple, Union
...@@ -1222,6 +1222,40 @@ class DummyObject(type): ...@@ -1222,6 +1222,40 @@ class DummyObject(type):
requires_backends(cls, cls._backends) requires_backends(cls, cls._backends)
def torch_required(func):
warnings.warn(
"The method `torch_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
FutureWarning,
)
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
return wrapper
def tf_required(func):
warnings.warn(
"The method `tf_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
FutureWarning,
)
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper
def is_torch_fx_proxy(x): def is_torch_fx_proxy(x):
if is_torch_fx_available(): if is_torch_fx_available():
import torch.fx import torch.fx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment