Unverified Commit 61d22f9c authored by Bram Vanroy's avatar Bram Vanroy Committed by GitHub
Browse files

Simplify cache vars and allow for TRANSFORMERS_CACHE env (#4226)

* simplify cache vars and allow for TRANSFORMERS_CACHE env

As it currently stands, "TRANSFORMERS_CACHE" is not an accepted variable. It seems that the these variables were not updated when moving from version pytorch_transformers to transformers. In addition, the fallback procedure could be improved. and simplified. Pathlib seems redundant here.

* Update file_utils.py
parent cd40cb88
......@@ -15,6 +15,7 @@ import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
......@@ -68,19 +69,10 @@ except ImportError:
)
default_cache_path = os.path.join(torch_cache_home, "transformers")
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
)
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
)
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment