Unverified Commit 21f01964 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #58 from lliimsft/master

Bug fix in examples;correct t_total for distributed training;run pred…
parents 32167cdf 0aaedcc0
...@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor): ...@@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, line[0]) guid = "%s-%s" % (set_type, line[0])
text_a = line[8]) text_a = line[8]
text_b = line[9]) text_b = line[9]
label = line[-1] label = line[-1]
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
...@@ -482,7 +483,7 @@ def main(): ...@@ -482,7 +483,7 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model # Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list), model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
if args.fp16: if args.fp16:
model.half() model.half()
...@@ -507,10 +508,13 @@ def main(): ...@@ -507,10 +508,13 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=t_total)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
...@@ -571,7 +575,7 @@ def main(): ...@@ -571,7 +575,7 @@ def main():
model.zero_grad() model.zero_grad()
global_step += 1 global_step += 1
if args.do_eval: if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir) eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer) eval_examples, label_list, args.max_seq_length, tokenizer)
...@@ -583,10 +587,8 @@ def main(): ...@@ -583,10 +587,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1: # Run prediction for full data
eval_sampler = SequentialSampler(eval_data) eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval() model.eval()
......
...@@ -25,6 +25,7 @@ import json ...@@ -25,6 +25,7 @@ import json
import math import math
import os import os
import random import random
import pickle
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np import numpy as np
...@@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -749,6 +751,10 @@ def main(): ...@@ -749,6 +751,10 @@ def main():
type=int, type=int,
default=1, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.") help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--do_lower_case",
default=True,
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
...@@ -845,20 +851,34 @@ def main(): ...@@ -845,20 +851,34 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=t_total)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
train_features = convert_examples_to_features( cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
examples=train_examples, args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
tokenizer=tokenizer, train_features = None
max_seq_length=args.max_seq_length, try:
doc_stride=args.doc_stride, with open(cached_train_features_file, "rb") as reader:
max_query_length=args.max_query_length, train_features = pickle.load(reader)
is_training=True) except:
train_features = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
train_features = pickle.dump(train_features, writer)
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features)) logger.info(" Num split examples = %d", len(train_features))
...@@ -913,7 +933,7 @@ def main(): ...@@ -913,7 +933,7 @@ def main():
model.zero_grad() model.zero_grad()
global_step += 1 global_step += 1
if args.do_predict: if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples( eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False) input_file=args.predict_file, is_training=False)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
...@@ -934,10 +954,8 @@ def main(): ...@@ -934,10 +954,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1: # Run prediction for full data
eval_sampler = SequentialSampler(eval_data) eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
model.eval() model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment