Commit 18e1f751 authored by Julien Chaumond

TF support

parent 31e5b5ff
@@ -24,7 +24,8 @@ import os
 import tensorflow as tf

 from .configuration_utils import PretrainedConfig
-from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
+                         cached_path, hf_bucket_url, is_remote_url)
 from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -257,12 +258,14 @@ class TFPreTrainedModel(tf.keras.Model):
                 raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
                     [WEIGHTS_NAME, TF2_WEIGHTS_NAME],
                     pretrained_model_name_or_path))
-            elif os.path.isfile(pretrained_model_name_or_path):
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                 archive_file = pretrained_model_name_or_path
             elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                 archive_file = pretrained_model_name_or_path + ".index"
             else:
-                archive_file = pretrained_model_name_or_path
+                archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
+                if from_pt:
+                    raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.")

         # redirect to the cache, if necessary
         try:
...
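For orientation, here is a minimal, self-contained sketch of the archive-file resolution this TF hunk introduces. The is_remote_url / hf_bucket_url helpers and the bucket base URL below are simplified stand-ins assumed for illustration, not the library's actual implementations; only the branching mirrors the diff.

# Illustrative sketch of the resolution logic added in the TF hunk above.
# The two helpers and BUCKET_BASE are assumptions, not the real .file_utils code.
import os
from urllib.parse import urlparse

TF2_WEIGHTS_NAME = "tf_model.h5"  # assumed value of the library constant
BUCKET_BASE = "https://s3.amazonaws.com/models.huggingface.co/bert"  # assumed base URL

def is_remote_url(path_or_url):
    # Anything with an http(s)/s3 scheme is treated as remote.
    return urlparse(path_or_url).scheme in ("http", "https", "s3")

def hf_bucket_url(identifier, postfix=None):
    # Map a bare model identifier (e.g. "bert-base-uncased") to a bucket URL.
    return "/".join(part for part in (BUCKET_BASE, identifier, postfix) if part)

def resolve_tf_archive(pretrained_model_name_or_path, from_pt=False):
    # Mirrors the branching in the diff: a local file or remote URL is used as is,
    # a TF1 ".index" checkpoint is picked up next, and anything else is treated
    # as a model identifier pointing at the TF2 weights on the bucket.
    if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    if os.path.isfile(pretrained_model_name_or_path + ".index"):
        return pretrained_model_name_or_path + ".index"
    if from_pt:
        raise EnvironmentError("Loading a TF model from a PyTorch checkpoint "
                               "is not supported when using a model identifier name.")
    return hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)

print(resolve_tf_archive("bert-base-uncased"))
# -> https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased/tf_model.h5

A bare identifier therefore falls through to a bucket URL for the TF2 weights file, while combining an identifier with from_pt raises the new EnvironmentError.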
@@ -372,7 +372,8 @@ class PreTrainedModel(nn.Module):
                 archive_file = pretrained_model_name_or_path + ".index"
             else:
                 archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
-                # todo do we want to support TF checkpoints here?
+                if from_tf:
+                    raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.")

         # redirect to the cache, if necessary
         try:
...
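Taken together, the two hunks make bare model identifiers behave symmetrically for the TF and PyTorch base classes. A hedged usage sketch (the class name, identifier, and local path are only examples, assuming the transformers package layout of this era):

from transformers import TFBertModel

# A bare identifier is resolved via hf_bucket_url(...) to the TF2 weights file.
model = TFBertModel.from_pretrained("bert-base-uncased")

# An identifier combined with from_pt now raises EnvironmentError; converting a
# PyTorch checkpoint still requires a local file or directory, e.g.:
model = TFBertModel.from_pretrained("./my_pytorch_checkpoint_dir", from_pt=True)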