Unverified Commit 47b91651 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Remove imports and use forward references in ONNX feature (#17926)

parent 5cdfff5d
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union
import transformers
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from .. import PretrainedConfig, is_tf_available, is_torch_available
from ..utils import logging
from .config import OnnxConfig
if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_torch_available():
......@@ -505,7 +509,7 @@ class FeaturesManager:
@staticmethod
def get_model_from_feature(
feature: str, model: str, framework: str = "pt", cache_dir: str = None
) -> Union[PreTrainedModel, TFPreTrainedModel]:
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
Attempts to retrieve a model from a model's name and the feature to be enabled.
......@@ -533,7 +537,7 @@ class FeaturesManager:
@staticmethod
def check_supported_model_or_raise(
model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default"
) -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment