"...composable_kernel_rocm.git" did not exist on "48fe853266ca86f287a2248bfe99bccf84fa257f"
Unverified Commit 61d22f9c authored by Bram Vanroy's avatar Bram Vanroy Committed by GitHub
Browse files

Simplify cache vars and allow for TRANSFORMERS_CACHE env (#4226)

* simplify cache vars and allow for TRANSFORMERS_CACHE env

As it currently stands, "TRANSFORMERS_CACHE" is not an accepted variable. It seems that the these variables were not updated when moving from version pytorch_transformers to transformers. In addition, the fallback procedure could be improved. and simplified. Pathlib seems redundant here.

* Update file_utils.py
parent cd40cb88
...@@ -15,6 +15,7 @@ import tempfile ...@@ -15,6 +15,7 @@ import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial, wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile from zipfile import ZipFile, is_zipfile
...@@ -68,19 +69,10 @@ except ImportError: ...@@ -68,19 +69,10 @@ except ImportError:
) )
default_cache_path = os.path.join(torch_cache_home, "transformers") default_cache_path = os.path.join(torch_cache_home, "transformers")
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
)
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
)
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5" TF2_WEIGHTS_NAME = "tf_model.h5"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment