Commit 326944d6 authored by thomwolf

add tensorboard to run_squad

parent d82e5dee
@@ -34,6 +34,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
+from tensorboardX import SummaryWriter
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
@@ -915,9 +917,8 @@ def main():
         model = torch.nn.DataParallel(model)
     if args.do_train:
+        writer = SummaryWriter()
         # Prepare data loader
         train_examples = read_squad_examples(
             input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
         cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
@@ -999,7 +1000,7 @@ def main():
         logger.info(" Num steps = %d", num_train_optimization_steps)
         model.train()
-        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
+        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                 if n_gpu == 1:
                     batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
@@ -1015,6 +1016,8 @@ def main():
                 else:
                     loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
+                    writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                    writer.add_scalar('loss', loss.item(), global_step)
                     if args.fp16:
                         # modify learning rate with special warm up BERT uses
                         # if args.fp16 is False, BertAdam is used and handles this automatically
...
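For context, below is a minimal, self-contained sketch of the logging pattern this commit adds: create a tensorboardX SummaryWriter and call add_scalar for the learning rate and loss once per optimizer step, indexed by a global step counter. The toy linear model, SGD optimizer, and random data are placeholders and not part of run_squad.py; since a plain torch optimizer has no get_lr() method like BertAdam, the sketch reads the rate from optimizer.param_groups[0]['lr'] instead.

```python
# Minimal sketch of the SummaryWriter logging added in this commit.
# Assumptions: tensorboardX is installed; the model/optimizer/data below
# are placeholders standing in for BertForQuestionAnswering + BertAdam.
import torch
from tensorboardX import SummaryWriter

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

writer = SummaryWriter()  # writes event files under ./runs/ by default
global_step = 0

for step in range(100):
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 1)

    loss = loss_fn(model(inputs), targets)
    loss.backward()

    # Log once per optimizer step, mirroring the commit's placement
    # after the gradient-accumulation check.
    writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar('loss', loss.item(), global_step)

    optimizer.step()
    optimizer.zero_grad()
    global_step += 1

writer.close()
```

The resulting event files can then be inspected with `tensorboard --logdir runs`.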