Commit 63ae5d21 authored by thomwolf's avatar thomwolf
Browse files

added cache_dir option in from_pretrained

parent 029bdc0d
...@@ -443,7 +443,7 @@ class PreTrainedBertModel(nn.Module): ...@@ -443,7 +443,7 @@ class PreTrainedBertModel(nn.Module):
module.bias.data.zero_() module.bias.data.zero_()
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
""" """
Instantiate a PreTrainedBertModel from a pre-trained model file. Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
...@@ -468,7 +468,7 @@ class PreTrainedBertModel(nn.Module): ...@@ -468,7 +468,7 @@ class PreTrainedBertModel(nn.Module):
archive_file = pretrained_model_name archive_file = pretrained_model_name
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError: except FileNotFoundError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment