"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "8be46438be1c848e01e4085f54ae997e2e918771"
Commit 60ea6c59 authored by thomwolf's avatar thomwolf
Browse files

added best practices for serialization in README and examples

parent 179a2c2f
...@@ -525,6 +525,82 @@ model = GPT2Model.from_pretrained('gpt2') ...@@ -525,6 +525,82 @@ model = GPT2Model.from_pretrained('gpt2')
``` ```
### Serialization best-practices: saving and re-loading a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL)
There are three types of files you need to save to be able to reload a fine-tuned model:
- the model it-self which should be saved following PyTorch serialization [best practices](https://pytorch.org/docs/stable/notes/serialization.html#best-practices),
- the configuration file of the model which is saved as a JSON file, and
- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
```python
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
output_dir = "./models/"
# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
# If we have a distributed model, save only the encapsulated model
# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
model_to_save = model.module if hasattr(model, 'module') else model
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)
# Step 2: Re-load the saved model and vocabulary
# Example for a Bert model
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case) # Add specific options if needed
# Example for a GPT model
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
```
Here is another way you can save and reload the model if you want to use specific paths for each type of files:
```python
output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"
# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
# If we have a distributed model, save only the encapsulated model
# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)
# Step 2: Re-load the saved model and vocabulary
# We didn't save using the predefined WEIGHTS_NAME, CONFIG_NAME names, we cannot load using `from_pretrained`.
# Here is how to do it in this situation:
# Example for a Bert model
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
# Example for a GPT model
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)
```
### Configuration classes ### Configuration classes
Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and build from configuration classes which containes the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are: Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and build from configuration classes which containes the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
......
...@@ -35,9 +35,9 @@ from torch.nn import CrossEntropyLoss, MSELoss ...@@ -35,9 +35,9 @@ from torch.nn import CrossEntropyLoss, MSELoss
from scipy.stats import pearsonr, spearmanr from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score from sklearn.metrics import matthews_corrcoef, f1_score
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
...@@ -863,15 +863,14 @@ def main(): ...@@ -863,15 +863,14 @@ def main():
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME) output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file) model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file) tokenizer.save_vocabulary(args.output_dir)
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels) model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
else: else:
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
model.to(device) model.to(device)
......
...@@ -39,8 +39,8 @@ import torch ...@@ -39,8 +39,8 @@ import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset) TensorDataset)
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path from pytorch_pretrained_bert import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
from pytorch_pretrained_bert.modeling_openai import WEIGHTS_NAME, CONFIG_NAME OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME)
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz" ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
......
...@@ -34,12 +34,12 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, ...@@ -34,12 +34,12 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import (BasicTokenizer, from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer, BertTokenizer,
whitespace_tokenize, VOCAB_NAME) whitespace_tokenize)
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
...@@ -1015,15 +1015,14 @@ def main(): ...@@ -1015,15 +1015,14 @@ def main():
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME) output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file) model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file) tokenizer.save_vocabulary(args.output_dir)
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = BertForQuestionAnswering.from_pretrained(args.output_dir) model = BertForQuestionAnswering.from_pretrained(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
else: else:
model = BertForQuestionAnswering.from_pretrained(args.bert_model) model = BertForQuestionAnswering.from_pretrained(args.bert_model)
......
...@@ -32,10 +32,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, ...@@ -32,10 +32,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME) from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -479,15 +479,14 @@ def main(): ...@@ -479,15 +479,14 @@ def main():
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME) output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file) model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file) tokenizer.save_vocabulary(args.output_dir)
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4) model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4)
tokenizer = BertTokenizer.from_pretrained(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
else: else:
model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4) model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
model.to(device) model.to(device)
......
...@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model, ...@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
from .optimization import BertAdam from .optimization import BertAdam
from .optimization_openai import OpenAIAdam from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
...@@ -33,6 +33,9 @@ except (AttributeError, ImportError): ...@@ -33,6 +33,9 @@ except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......
...@@ -32,7 +32,7 @@ import torch ...@@ -32,7 +32,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from .file_utils import cached_path from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
} }
CONFIG_NAME = 'bert_config.json' BERT_CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path): def load_tf_weights_in_bert(model, tf_checkpoint_path):
...@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module): ...@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module):
serialization_dir = tempdir serialization_dir = tempdir
# Load config # Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME) config_file = os.path.join(serialization_dir, CONFIG_NAME)
if not os.path.exists(config_file):
# Backward compatibility with old naming format
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
config = BertConfig.from_json_file(config_file) config = BertConfig.from_json_file(config_file)
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
......
...@@ -34,7 +34,7 @@ import torch.nn as nn ...@@ -34,7 +34,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .file_utils import cached_path from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__) ...@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"} PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"} PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model """ Load tf checkpoints in a pytorch model
""" """
......
...@@ -34,7 +34,7 @@ import torch.nn as nn ...@@ -34,7 +34,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .file_utils import cached_path from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__) ...@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"} PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"} PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path): def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here) """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
......
...@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter ...@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
PRETRAINED_CONFIG_ARCHIVE_MAP = { PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
} }
CONFIG_NAME = 'config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = 'model.ckpt'
def build_tf_to_pytorch_map(model, config): def build_tf_to_pytorch_map(model, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment