"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a664b0c40e4664533d041b11b558b33575c6ec64"
Unverified commit edd614ea, authored by Dhruv Nair, committed by GitHub

[CI] Fix TRANSFORMERS_FLAX_WEIGHTS_NAME import issue (#12354)

update
parent 7e7e62c6
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
 if is_transformers_available():
     import transformers
     from transformers import PreTrainedModel, PreTrainedTokenizerBase
-    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
+
+    if is_transformers_version("<=", "4.56.2"):
+        from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME

 if is_accelerate_available():
     import accelerate
     from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
     ]
     if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+        if is_transformers_version("<=", "4.56.2"):
+            weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]

     # model_pytorch, diffusion_model_pytorch, ...
     weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
     ]
     if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+        if is_transformers_version("<=", "4.56.2"):
+            weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]

     allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
     ]
     if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+        if is_transformers_version("<=", "4.56.2"):
+            weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]

     # model_pytorch, diffusion_model_pytorch, ...
     weight_prefixes = [w.split(".")[0] for w in weight_names]
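The pattern in the diff is to gate both the import and the filename list on the installed transformers version, so the module stays importable when a constant disappears upstream. Below is a minimal, self-contained sketch of that idea, not the actual diffusers code: the installed_transformers_at_most helper is an illustrative stand-in for diffusers' is_transformers_version("<=", "4.56.2") check, and the filenames in weight_names are placeholders.

# Sketch only: assumes `transformers` and `packaging` are installed.
import importlib.metadata

from packaging import version


def installed_transformers_at_most(pivot: str) -> bool:
    # Compare the installed transformers release against a pivot version string.
    installed = version.parse(importlib.metadata.version("transformers"))
    return installed <= version.parse(pivot)


# Placeholder filenames; the real code builds this list from library constants.
weight_names = ["pytorch_model.bin", "model.safetensors"]

if installed_transformers_at_most("4.56.2"):
    # Releases up to 4.56.2 still expose FLAX_WEIGHTS_NAME; newer ones apparently
    # do not, so both the import and the list entry are guarded.
    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME

    weight_names.append(TRANSFORMERS_FLAX_WEIGHTS_NAME)

print(weight_names)

Guarding the list entries the same way in is_safetensors_compatible, filter_model_files, and variant_compatible_siblings keeps the flax filename filtering for users on older transformers releases without referencing an undefined name on newer ones.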