Commit dbc318a4 authored by thomwolf's avatar thomwolf
Browse files

Clean up code and slightly speed up multi-GPU training

parent 6bb7510a
...@@ -467,6 +467,6 @@ class BertForQuestionAnswering(nn.Module): ...@@ -467,6 +467,6 @@ class BertForQuestionAnswering(nn.Module):
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions) end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
return total_loss, (start_logits, end_logits) return total_loss
else: else:
return start_logits, end_logits return start_logits, end_logits
...@@ -514,13 +514,13 @@ def main(): ...@@ -514,13 +514,13 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train() model.train()
for epoch in trange(int(args.num_train_epochs), desc="Epoch"): for _ in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0 tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0 nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch input_ids, input_mask, segment_ids, label_ids = batch
loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1: if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu. loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
...@@ -564,6 +564,7 @@ def main(): ...@@ -564,6 +564,7 @@ def main():
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
with torch.no_grad():
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
logits = logits.detach().cpu().numpy() logits = logits.detach().cpu().numpy()
......
...@@ -855,11 +855,11 @@ def main(): ...@@ -855,11 +855,11 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train() model.train()
for epoch in trange(int(args.num_train_epochs), desc="Epoch"): for _ in trange(int(args.num_train_epochs), desc="Epoch"):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, start_positions, end_positions = batch input_ids, input_mask, segment_ids, start_positions, end_positions = batch
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
if n_gpu > 1: if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu. loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment