Commit 270fa2f2 authored by thomwolf's avatar thomwolf
Browse files

add pretrained loading from state_dict

parent 174cdbcc
...@@ -445,9 +445,9 @@ class PreTrainedBertModel(nn.Module): ...@@ -445,9 +445,9 @@ class PreTrainedBertModel(nn.Module):
module.bias.data.zero_() module.bias.data.zero_()
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
""" """
Instantiate a PreTrainedBertModel from a pre-trained model file. Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
Params: Params:
...@@ -461,6 +461,8 @@ class PreTrainedBertModel(nn.Module): ...@@ -461,6 +461,8 @@ class PreTrainedBertModel(nn.Module):
- a path or url to a pretrained model archive containing: - a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model . `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class *inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification) (ex: num_labels for BertForSequenceClassification)
""" """
...@@ -502,8 +504,9 @@ class PreTrainedBertModel(nn.Module): ...@@ -502,8 +504,9 @@ class PreTrainedBertModel(nn.Module):
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) if state_dict is None:
state_dict = torch.load(weights_path) weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
missing_keys = [] missing_keys = []
unexpected_keys = [] unexpected_keys = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment