Unverified commit 705d6536 authored by Zachary Mueller, committed by GitHub

Fix multiproc metrics in no_trainer examples (#16865)

parent 175da8d1
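
The change is the same in every touched example script: the evaluation loop now gathers predictions and references across processes with a single accelerator.gather call, keeps a running count of how many samples have already been scored, and on the last batch slices off the entries that distributed sharding duplicated, so the metric is computed over exactly len(eval_dataloader.dataset) examples. Below is a minimal, self-contained sketch of that pattern; the toy linear model, random dataset, and accuracy metric are placeholders rather than code from the commit, and the last-batch check is written as step == len(eval_dataloader) - 1 since enumerate() is zero-based.

# A minimal, self-contained sketch of the pattern applied in each no_trainer
# example. The tiny model, dataset, and metric below are placeholders, not
# code from the commit.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from datasets import load_metric

accelerator = Accelerator()

# 99 samples: with more than one process, the sharded dataloader duplicates
# examples so that every process receives the same number of batches.
features = torch.randn(99, 4)
labels = torch.randint(0, 2, (99,))
eval_dataloader = DataLoader(TensorDataset(features, labels), batch_size=8)

model = torch.nn.Linear(4, 2)
metric = load_metric("accuracy")

model, eval_dataloader = accelerator.prepare(model, eval_dataloader)

model.eval()
samples_seen = 0
for step, (inputs, refs) in enumerate(eval_dataloader):
    with torch.no_grad():
        logits = model(inputs)
    predictions = logits.argmax(dim=-1)
    # One gather call on a tuple replaces the two separate gather calls.
    predictions, references = accelerator.gather((predictions, refs))
    # In a multi-process run the last batch contains the duplicated samples,
    # so truncate back to the true dataset size before scoring.
    if accelerator.num_processes > 1:
        if step == len(eval_dataloader) - 1:
            predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
            references = references[: len(eval_dataloader.dataset) - samples_seen]
        else:
            samples_seen += references.shape[0]
    metric.add_batch(predictions=predictions, references=references)

eval_metric = metric.compute()
accelerator.print(eval_metric)

Launched with accelerate launch --num_processes 2, the truncation keeps metric.compute() from scoring the padded duplicates a second time; in a single-process run the branch is simply skipped.
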
@@ -457,12 +457,21 @@ def main():
                 break

         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )

         eval_metric = metric.compute()
...
@@ -559,13 +559,22 @@ def main():
                 break

         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )

         eval_metric = metric.compute()
...
@@ -567,6 +567,7 @@ def main():
         logger.info("***** Running evaluation *****")
         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
             outputs = model(**batch)
@@ -575,9 +576,19 @@ def main():
             )
             predictions = upsampled_logits.argmax(dim=1)
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )

         eval_metrics = metric.compute(
...
@@ -628,6 +628,7 @@ def main():
             "max_length": args.val_max_target_length if args is not None else config.max_length,
             "num_beams": args.num_beams,
         }
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 generated_tokens = accelerator.unwrap_model(model).generate(
@@ -644,8 +645,9 @@ def main():
                     # If we did not pad to max length, we need to pad the labels too
                     labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

-                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
-                labels = accelerator.gather(labels).cpu().numpy()
+                generated_tokens, labels = accelerator.gather((generated_tokens, labels))
+                generated_tokens = generated_tokens.cpu().numpy()
+                labels = labels.cpu().numpy()

                 if args.ignore_pad_token_for_loss:
                     # Replace -100 in the labels as we can't decode them.
@@ -656,8 +658,18 @@ def main():
                 decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

                 decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

-                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+                # If we are in a multiprocess environment, the last batch has duplicates
+                if accelerator.num_processes > 1:
+                    if step == len(eval_dataloader):
+                        decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
+                        decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
+                    else:
+                        samples_seen += decoded_labels.shape[0]
+                metric.add_batch(
+                    predictions=decoded_preds,
+                    references=decoded_labels,
+                )
         result = metric.compute(use_stemmer=True)
         # Extract a few results from ROUGE
         result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
...
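
In the generation-based examples (ROUGE above, BLEU further down), the objects being truncated are Python lists of decoded strings rather than tensors, so in a sketch of that branch the per-batch count comes from len(). The snippet below is an illustrative stand-in with dummy values, not code from the commit.

# Illustrative only: after tokenizer.batch_decode the predictions and
# references are plain lists of strings, so slicing and len() are the list
# operations at work. The values stand in for one gathered final batch.
decoded_preds = ["summary one", "summary two", "duplicate of summary one"]
decoded_labels = ["reference one", "reference two", "reference one"]

dataset_size = 2      # stands in for len(eval_dataloader.dataset)
samples_seen = 0      # accumulated over the earlier batches
is_last_batch = True  # stands in for the last-step check
num_processes = 2     # stands in for accelerator.num_processes

if num_processes > 1:
    if is_last_batch:
        decoded_preds = decoded_preds[: dataset_size - samples_seen]
        decoded_labels = decoded_labels[: dataset_size - samples_seen]
    else:
        samples_seen += len(decoded_labels)

assert len(decoded_preds) == dataset_size  # the duplicated sample is gone
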
@@ -506,12 +506,21 @@ def main():
                 break

         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )

         eval_metric = metric.compute()
...
@@ -658,6 +658,7 @@ def main():
                 break

         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
@@ -666,9 +667,14 @@ def main():
             if not args.pad_to_max_length:  # necessary to pad predictions and labels for being gathered
                 predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
                 labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
-            predictions_gathered = accelerator.gather(predictions)
-            labels_gathered = accelerator.gather(labels)
+            predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
+                    labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += labels_gathered.shape[0]
             preds, refs = get_labels(predictions_gathered, labels_gathered)
             metric.add_batch(
                 predictions=preds,
...
@@ -613,6 +613,7 @@ def main():
             "max_length": args.val_max_target_length if args is not None else config.max_length,
             "num_beams": args.num_beams,
         }
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 generated_tokens = accelerator.unwrap_model(model).generate(
@@ -641,6 +642,14 @@ def main():
                 decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

+                # If we are in a multiprocess environment, the last batch has duplicates
+                if accelerator.num_processes > 1:
+                    if step == len(eval_dataloader):
+                        decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
+                        decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
+                    else:
+                        samples_seen += decoded_labels.shape[0]
                 metric.add_batch(predictions=decoded_preds, references=decoded_labels)
         eval_metric = metric.compute()
         logger.info({"bleu": eval_metric["score"]})
...
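
To make the bookkeeping concrete with hypothetical numbers: with 99 evaluation samples, 2 processes, and a per-device batch size of 8, the sharded dataloader pads the dataset to 100 samples so that each process sees 50, i.e. 7 batches. The first 6 gathered steps contribute 6 x 16 = 96 predictions and leave samples_seen at 96; the final step gathers 4 more, of which only len(eval_dataloader.dataset) - samples_seen = 99 - 96 = 3 are kept, so the single duplicated sample is dropped and metric.compute() runs over exactly 99 examples.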