Commit 3b7fb48c authored by thomwolf's avatar thomwolf
Browse files

fix loading from tf/pt

parent a049c804
...@@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/') ...@@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/')
pt_model = BertForSequenceClassification.from_pretrained('./runs/') pt_model = BertForSequenceClassification.from_pretrained('./runs/')
# Quickly inspect a few predictions # Quickly inspect a few predictions
inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results", add_special_tokens=True)
pred = pt_model(torch.tensor([tokens]))
# Divers # Divers
import torch import torch
......
...@@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else: else:
raise EnvironmentError("Error no file named {} found in directory {}".format( raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME), [WEIGHTS_NAME, TF2_WEIGHTS_NAME],
pretrained_model_name_or_path)) pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
......
...@@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module): ...@@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else: else:
raise EnvironmentError("Error no file named {} found in directory {}".format( raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
pretrained_model_name_or_path)) pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment