@@ -178,6 +177,7 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
...
...
@@ -263,12 +263,14 @@ class TFPreTrainedModel(tf.keras.Model):
raiseEnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
@@ -71,6 +71,15 @@ class PreTrainedModel(nn.Module):
# NOTE(review): default is a no-op loader; presumably subclasses assign a real
# TF-checkpoint conversion function here — confirm against the rest of the module.
load_tf_weights = lambda model, config, path: None
# Prefix used to locate the base (headless) model attribute inside derived models;
# empty by default — subclasses are expected to override it.
base_model_prefix = ""
@property
def dummy_inputs(self):
    """Dummy inputs to do a forward pass in the network.

    Returns:
        dict: ``{'input_ids': torch.Tensor}`` built from the module-level
        ``DUMMY_INPUTS`` constant (defined elsewhere in this file).
    """
    # DUMMY_INPUTS is a module-level constant not visible in this chunk;
    # torch.tensor converts it into a tensor suitable for a forward pass.
    return {'input_ids': torch.tensor(DUMMY_INPUTS)}
def__init__(self,config,*inputs,**kwargs):
super(PreTrainedModel,self).__init__()
ifnotisinstance(config,PretrainedConfig):
...
...
@@ -160,8 +169,7 @@ class PreTrainedModel(nn.Module):
base_model.vocab_size=new_num_tokens
# Tie weights again if needed
ifhasattr(self,'tie_weights'):
self.tie_weights()
self.tie_weights()
returnmodel_embeds
...
...
@@ -265,6 +273,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
...
...
@@ -318,10 +327,6 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)