Commit 89d47230 authored by thomwolf's avatar thomwolf
Browse files

clean up classification model output

parent 7f7c41b0
...@@ -546,7 +546,7 @@ def main(): ...@@ -546,7 +546,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch input_ids, input_mask, segment_ids, label_ids = batch
loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1: if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu. loss = loss.mean() # mean() to average on multi-gpu.
if args.fp16 and args.loss_scale != 1.0: if args.fp16 and args.loss_scale != 1.0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment