Unverified Commit b6a65ae5 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix circular import in onnx.utils (#17577)

* Fix circular import in onnx.utils

* Add comment for test fetcher

* Here too

* Style
parent 9aa230aa
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
from ctypes import c_float, sizeof from ctypes import c_float, sizeof
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import TYPE_CHECKING, Optional, Union
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
if TYPE_CHECKING:
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
class ParameterFormat(Enum): class ParameterFormat(Enum):
...@@ -66,7 +68,7 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm ...@@ -66,7 +68,7 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
return num_parameters * dtype.size return num_parameters * dtype.size
def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]: def get_preprocessor(model_name: str) -> Optional[Union["AutoTokenizer", "AutoFeatureExtractor", "AutoProcessor"]]:
""" """
Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`. Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
...@@ -79,6 +81,9 @@ def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatu ...@@ -79,6 +81,9 @@ def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatu
returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
`None` if no preprocessor is found. `None` if no preprocessor is found.
""" """
# Avoid circular imports by only importing this here.
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
try: try:
return AutoProcessor.from_pretrained(model_name) return AutoProcessor.from_pretrained(model_name)
except (ValueError, OSError, KeyError): except (ValueError, OSError, KeyError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment