Serialization best-practices
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This section explains how to save and re-load a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL).
There are three types of files you need to save to be able to reload a fine-tuned model:

* the model itself, which should be saved following PyTorch serialization `best practices <https://pytorch.org/docs/stable/notes/serialization.html#best-practices>`__\ ,
* the configuration file of the model which is saved as a JSON file, and
* the vocabulary (and the merges for the BPE-based models GPT and GPT-2).

The *default filenames* of these files are as follows:

* the model weights file: ``pytorch_model.bin``\ ,
* the configuration file: ``config.json``\ ,
* the vocabulary file: ``vocab.txt`` for BERT and Transformer-XL, ``vocab.json`` for GPT/GPT-2 (BPE vocabulary),
* for GPT/GPT-2 (BPE vocabulary) the additional merges file: ``merges.txt``.

If you save a model using these *default filenames*, you can then re-load the model and tokenizer using the ``from_pretrained()`` method.
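
In recent versions of the library, the model and tokenizer also expose a ``save_pretrained()`` method that writes these files under their default names in one call. A minimal sketch, assuming ``model``, ``tokenizer`` and ``output_dir`` are defined as in the example below:

.. code-block:: python

   model.save_pretrained(output_dir)      # writes the weights and config.json
   tokenizer.save_pretrained(output_dir)  # writes the vocabulary (and merges for BPE models)

   model = BertForQuestionAnswering.from_pretrained(output_dir)
   tokenizer = BertTokenizer.from_pretrained(output_dir)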

Here is the recommended way of saving the model, configuration and vocabulary to an ``output_dir`` directory and reloading the model and tokenizer afterwards:

.. code-block:: python

   import os

   import torch
   from transformers import (
       CONFIG_NAME, WEIGHTS_NAME,
       BertForQuestionAnswering, BertTokenizer,
       OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
   )

   output_dir = "./models/"

   # Step 1: Save a model, configuration and vocabulary that you have fine-tuned

   # If we have a distributed model, save only the encapsulated model
   # (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
   model_to_save = model.module if hasattr(model, 'module') else model

   # If we save using the predefined names, we can load using `from_pretrained`
   output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
   output_config_file = os.path.join(output_dir, CONFIG_NAME)

   torch.save(model_to_save.state_dict(), output_model_file)
   model_to_save.config.to_json_file(output_config_file)
   tokenizer.save_pretrained(output_dir)

   # Step 2: Re-load the saved model and vocabulary

   # Example for a Bert model
   model = BertForQuestionAnswering.from_pretrained(output_dir)
   tokenizer = BertTokenizer.from_pretrained(output_dir)  # Add specific options if needed

   # Example for a GPT model
   model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
   tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
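
After Step 1, ``output_dir`` should contain the files listed above. As a quick sanity check (a minimal sketch; depending on the library version the tokenizer may write a few extra files such as ``special_tokens_map.json``):

.. code-block:: python

   import os

   # For a fine-tuned BERT model we expect at least:
   #   pytorch_model.bin, config.json, vocab.txt
   # (GPT/GPT-2 write vocab.json and merges.txt instead of vocab.txt)
   print(sorted(os.listdir(output_dir)))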

Here is another way to save and reload the model if you want to use specific paths for each type of file:

.. code-block:: python

   import torch
   from transformers import (
       BertConfig, BertForQuestionAnswering, BertTokenizer,
       OpenAIGPTConfig, OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
   )

   output_model_file = "./models/my_own_model_file.bin"
   output_config_file = "./models/my_own_config_file.bin"
   output_vocab_file = "./models/my_own_vocab_file.bin"

   # Step 1: Save a model, configuration and vocabulary that you have fine-tuned

   # If we have a distributed model, save only the encapsulated model
   # (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
   model_to_save = model.module if hasattr(model, 'module') else model

   torch.save(model_to_save.state_dict(), output_model_file)
   model_to_save.config.to_json_file(output_config_file)
   tokenizer.save_vocabulary(output_vocab_file)

   # Step 2: Re-load the saved model and vocabulary

   # We didn't save using the predefined WEIGHTS_NAME and CONFIG_NAME names,
   # so we cannot load using `from_pretrained`. Here is how to do it in this situation:

   # Example for a Bert model
   config = BertConfig.from_json_file(output_config_file)
   model = BertForQuestionAnswering(config)
   state_dict = torch.load(output_model_file)
   model.load_state_dict(state_dict)
   # `args.do_lower_case` comes from your training script and should match
   # the casing used during fine-tuning
   tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)

   # Example for a GPT model
   config = OpenAIGPTConfig.from_json_file(output_config_file)
   model = OpenAIGPTDoubleHeadsModel(config)
   state_dict = torch.load(output_model_file)
   model.load_state_dict(state_dict)
   tokenizer = OpenAIGPTTokenizer(output_vocab_file)
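
If the model was fine-tuned on a GPU but is re-loaded on a CPU-only machine, ``torch.load()`` needs a ``map_location`` argument to remap the saved tensors. A minimal sketch:

.. code-block:: python

   import torch

   # Remap all tensors in the checkpoint to CPU before loading the weights
   state_dict = torch.load(output_model_file, map_location="cpu")
   model.load_state_dict(state_dict)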