Commit 29570db2 authored by thomwolf's avatar thomwolf
Browse files

allowing from_pretrained to load from url directly

parent 2e2f9fed
......@@ -259,8 +259,10 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
else:
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
......
......@@ -365,9 +365,12 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
pretrained_model_name_or_path + ".index")
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment