Unverified Commit edd614ea authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[CI] Fix TRANSFORMERS_FLAX_WEIGHTS_NAME import issue (#12354)

update
parent 7e7e62c6
......@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
if is_transformers_version("<=", "4.56.2"):
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
......@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
......@@ -191,7 +195,9 @@ def filter_model_files(filenames):
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
......@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment