Commit 3963d57c authored by Ailing Zhang
Browse files

move pytorch_pretrained_bert cache folder under the same path as torch

parent b832d5bb
...@@ -84,7 +84,7 @@ def bertTokenizer(*args, **kwargs): ...@@ -84,7 +84,7 @@ def bertTokenizer(*args, **kwargs):
Example: Example:
>>> sentence = 'Hello, World!' >>> sentence = 'Hello, World!'
>>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False) >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
>>> toks = tokenizer.tokenize(sentence) >>> toks = tokenizer.tokenize(sentence)
['Hello', '##,', 'World', '##!'] ['Hello', '##,', 'World', '##!']
>>> ids = tokenizer.convert_tokens_to_ids(toks) >>> ids = tokenizer.convert_tokens_to_ids(toks)
......
...@@ -22,6 +22,15 @@ import requests ...@@ -22,6 +22,15 @@ import requests
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from tqdm import tqdm from tqdm import tqdm
# Keep this package's cache under the same root directory that torch uses.
try:
    # Reuse torch's own resolver when the installed torch provides it.
    from torch.hub import _get_torch_home
    torch_cache_home = _get_torch_home()
except ImportError:
    # Helper not importable: fall back to $TORCH_HOME, then
    # $XDG_CACHE_HOME/torch, then ~/.cache/torch (with '~' expanded).
    torch_cache_home = os.path.expanduser(
        os.environ.get(
            'TORCH_HOME',
            os.path.join(os.environ.get('XDG_CACHE_HOME', '~/.cache'), 'torch')
        )
    )
# Default location for downloaded files unless overridden elsewhere.
default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
try: try:
from urllib.parse import urlparse from urllib.parse import urlparse
except ImportError: except ImportError:
...@@ -29,11 +38,11 @@ except ImportError: ...@@ -29,11 +38,11 @@ except ImportError:
try: try:
from pathlib import Path from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', PYTORCH_PRETRAINED_BERT_CACHE = Path(
Path.home() / '.pytorch_pretrained_bert')) os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
except (AttributeError, ImportError): except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) default_cache_path)
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment