Commit 18e1f751 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

TF support

parent 31e5b5ff
......@@ -24,7 +24,8 @@ import os
import tensorflow as tf
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
cached_path, hf_bucket_url, is_remote_url)
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
logger = logging.getLogger(__name__)
......@@ -257,12 +258,14 @@ class TFPreTrainedModel(tf.keras.Model):
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
if from_pt:
raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.")
# redirect to the cache, if necessary
try:
......
......@@ -372,7 +372,8 @@ class PreTrainedModel(nn.Module):
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
# todo do we want to support TF checkpoints here?
if from_tf:
raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.")
# redirect to the cache, if necessary
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment