Commit 179a2c2f authored by thomwolf

Update examples to work with the new serialization semantics

parent b3c6ee0a
@@ -37,7 +37,7 @@ from sklearn.metrics import matthews_corrcoef, f1_score
 
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -857,18 +857,21 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
 
-        # Save a trained model and the associated configuration
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+
+        # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
         output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())
+        output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(output_vocab_file)
 
-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForSequenceClassification(config, num_labels=num_labels)
-        model.load_state_dict(torch.load(output_model_file))
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
+        tokenizer = BertTokenizer.from_pretrained(args.output_dir)
     else:
         model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)
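The hunk above (the BERT sequence classification example) replaces the ad-hoc save/load code with the library's predefined file names, so a fine-tuned model, its configuration and its vocabulary can all be reloaded from a single directory with `from_pretrained`. Below is a minimal sketch of that round trip, assuming pytorch_pretrained_bert is installed and the `bert-base-uncased` weights can be downloaded; the output directory name and `num_labels=2` are illustrative and not part of this commit.

import os
import torch
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME

output_dir = "./finetuned_bert"  # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)

# A freshly loaded (or fine-tuned) model and tokenizer
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# Save weights, config and vocabulary under the predefined file names ...
model_to_save = model.module if hasattr(model, "module") else model  # unwrap DataParallel if needed
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
model_to_save.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
tokenizer.save_vocabulary(os.path.join(output_dir, VOCAB_NAME))

# ... so the whole directory can be reloaded in one call each
model = BertForSequenceClassification.from_pretrained(output_dir, num_labels=2)
tokenizer = BertTokenizer.from_pretrained(output_dir)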
@@ -40,6 +40,7 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 
 from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
+from pytorch_pretrained_bert.modeling_openai import WEIGHTS_NAME, CONFIG_NAME
 
 ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
@@ -218,15 +219,20 @@ def main():
     # Save a trained model
     if args.do_train:
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-        config = model.config
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
         torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)
 
-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = OpenAIGPTDoubleHeadsModel(config)
-        model.load_state_dict(model_state_dict)
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir)
+        tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
         model.to(device)
 
     if args.do_eval:
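The OpenAI GPT example follows the same pattern; the visible difference is that `tokenizer.save_vocabulary` is given the output directory rather than a single vocab file path. A sketch of the same round trip for this model, assuming the `openai-gpt` weights can be downloaded and using an illustrative directory name:

import os
import torch
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer
from pytorch_pretrained_bert.modeling_openai import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./finetuned_gpt"  # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)

model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-gpt")
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")

# Save weights and config under the predefined names; the tokenizer writes its files into the directory
model_to_save = model.module if hasattr(model, "module") else model
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
model_to_save.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
tokenizer.save_vocabulary(output_dir)

# Reload model and tokenizer from the directory
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)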
@@ -39,7 +39,7 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfi
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                    BertTokenizer,
-                                                   whitespace_tokenize)
+                                                   whitespace_tokenize, VOCAB_NAME)
 
 if sys.version_info[0] == 2:
     import cPickle as pickle
@@ -1009,18 +1009,21 @@ def main():
                     global_step += 1
 
     if args.do_train:
-        # Save a trained model and the associated configuration
+        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+
+        # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
         output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())
+        output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(output_vocab_file)
 
-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForQuestionAnswering(config)
-        model.load_state_dict(torch.load(output_model_file))
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = BertForQuestionAnswering.from_pretrained(args.output_dir)
+        tokenizer = BertTokenizer.from_pretrained(args.output_dir)
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
@@ -35,7 +35,7 @@ from tqdm import tqdm, trange
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME)
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -473,18 +473,21 @@ def main():
 
     if args.do_train:
-        # Save a trained model and the associated configuration
+        # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+
+        # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
         output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())
+        output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(output_vocab_file)
 
-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForMultipleChoice(config, num_choices=4)
-        model.load_state_dict(torch.load(output_model_file))
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4)
+        tokenizer = BertTokenizer.from_pretrained(args.output_dir)
     else:
         model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)