Commit eebc8abb authored by thomwolf

clarify and unify model saving logic in examples

parent 81c7e3ec
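This change makes all three example scripts follow the same save/reload convention: after training, both the fine-tuned weights and the model configuration are written to `--output_dir` under the library's canonical file names (`WEIGHTS_NAME`, `CONFIG_NAME`), and the model is then rebuilt from those two files. A minimal sketch of the shared pattern (the `save_model` helper is hypothetical; `model` stands for any of the fine-tuned example models):

```python
import os
import torch
from pytorch_pretrained_bert.modeling import WEIGHTS_NAME, CONFIG_NAME

def save_model(model, output_dir):
    """Save fine-tuned weights plus the model configuration, as the examples now do."""
    # Unwrap (Distributed)DataParallel so only the bare model is serialized.
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
    with open(os.path.join(output_dir, CONFIG_NAME), 'w') as f:
        f.write(model_to_save.config.to_json_string())
```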
README.md
@@ -779,7 +779,8 @@ python run_classifier.py \
   --train_batch_size 32 \
   --learning_rate 2e-5 \
   --num_train_epochs 3.0 \
-  --output_dir /tmp/mrpc_output/
+  --output_dir /tmp/mrpc_output/ \
+  --fp16
 ```

 #### SQuAD
...
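One note on the new `--fp16` flag in the README example: in this code base, half-precision training depends on NVIDIA's apex extension, which the example scripts import lazily. Roughly (a sketch from memory of the surrounding code, not part of this diff; exact import locations may differ by apex version):

```python
# Sketch of the guard the examples use around the fp16 code path; apex supplies
# FusedAdam and FP16_Optimizer, which back --fp16 training.
try:
    from apex.optimizers import FP16_Optimizer, FusedAdam
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                      "to use distributed and fp16 training.")
```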
examples/run_classifier.py
@@ -23,7 +23,6 @@ import logging
 import os
 import random
 import sys
-from io import open

 import numpy as np
 import torch
@@ -33,7 +32,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
...@@ -92,7 +91,7 @@ class DataProcessor(object): ...@@ -92,7 +91,7 @@ class DataProcessor(object):
@classmethod @classmethod
def _read_tsv(cls, input_file, quotechar=None): def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "rb") as f: with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = [] lines = []
for line in reader: for line in reader:
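The `"rb"` to `"r"` switch belongs with the removal of `from io import open` above: the examples now target Python 3 only, where `csv.reader` requires a text-mode file object and fails on bytes. A small self-contained illustration (the `dev.tsv` file is created here purely for demonstration):

```python
import csv

# Create a throwaway TSV file for the demonstration.
with open("dev.tsv", "w", encoding="utf-8") as f:
    f.write("sentence1\tsentence2\tlabel\n")

# Binary mode: under Python 3 the csv module raises csv.Error because the
# file iterator yields bytes instead of str.
with open("dev.tsv", "rb") as f:
    try:
        next(csv.reader(f, delimiter="\t"))
    except csv.Error as e:
        print("binary mode fails:", e)

# Text mode, as in the patched _read_tsv: parses cleanly.
with open("dev.tsv", "r", encoding="utf-8") as f:
    print(list(csv.reader(f, delimiter="\t")))
```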
@@ -324,6 +323,10 @@ def main():
                         help="The output directory where the model predictions and checkpoints will be written.")

     ## Other parameters
+    parser.add_argument("--cache_dir",
+                        default="",
+                        type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length",
                         default=128,
                         type=int,
...@@ -383,9 +386,17 @@ def main(): ...@@ -383,9 +386,17 @@ def main():
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n" "0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n") "Positive power of 2: static loss scaling value.\n")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
processors = { processors = {
"cola": ColaProcessor, "cola": ColaProcessor,
"mnli": MnliProcessor, "mnli": MnliProcessor,
...@@ -451,8 +462,9 @@ def main(): ...@@ -451,8 +462,9 @@ def main():
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model # Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
model = BertForSequenceClassification.from_pretrained(args.bert_model, model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)), cache_dir=cache_dir,
num_labels = num_labels) num_labels = num_labels)
if args.fp16: if args.fp16:
model.half() model.half()
@@ -549,15 +561,21 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())

-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForSequenceClassification(config, num_labels=num_labels)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)

     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
...
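With both files on disk, the fine-tuned classifier can be rebuilt outside the training script. A minimal sketch, assuming the MRPC setup from the README above (two labels, output directory `/tmp/mrpc_output/`):

```python
import os
import torch
from pytorch_pretrained_bert.modeling import (BertConfig, BertForSequenceClassification,
                                              WEIGHTS_NAME, CONFIG_NAME)

output_dir = "/tmp/mrpc_output/"  # the --output_dir used during fine-tuning

# Rebuild the architecture from the saved config, then load the fine-tuned
# weights, mirroring the reload branch in the diff above.
config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
model = BertForSequenceClassification(config, num_labels=2)  # MRPC is a 2-label task
model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))
model.eval()  # inference mode
```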
examples/run_squad.py
@@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,
@@ -1001,14 +1001,19 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())

-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForQuestionAnswering(config)
+        model.load_state_dict(torch.load(output_model_file))
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
...
examples/run_swag.py
@@ -469,18 +469,25 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
-
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForMultipleChoice.from_pretrained(args.bert_model,
-        state_dict=model_state_dict,
-        num_choices=4)
+    if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForMultipleChoice(config, num_choices=4)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)

     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
         eval_features = convert_examples_to_features(
...