"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3df3b9d4bf006ab193b3c1257f3436b9fdb91759"
Unverified Commit ae06bce8 authored by Wang, Yi, committed by GitHub

exclude jit time from the speed metric calculation of evaluation and prediction (#20553)


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 25e10da4
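
Context for the hunks below: speed_metrics in transformers.trainer_utils measures runtime as the wall-clock time elapsed since start_time, so shifting start_time forward by the recorded JIT compilation time removes that one-off cost from the reported runtime and throughput. A minimal sketch of the pattern, using a simplified stand-in for speed_metrics and hypothetical timings:

import math
import time

def speed_metrics(split, start_time, num_samples=None, num_steps=None):
    # Simplified stand-in for transformers.trainer_utils.speed_metrics:
    # runtime is measured from start_time, so anything added to start_time
    # is excluded from the reported runtime and *_per_second values.
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        result[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
    if num_steps is not None:
        result[f"{split}_steps_per_second"] = round(num_steps / runtime, 3)
    return result

# Hypothetical timings: ~3 s of one-off JIT compilation plus ~2 s of real
# evaluation work; only the 2 s should be billed to throughput.
start_time = time.time()
time.sleep(3.0)  # stands in for the torch_jit_model_eval() compilation
time.sleep(2.0)  # stands in for the actual evaluation loop
metrics = {"eval_jit_compilation_time": 3.0}

if "eval_jit_compilation_time" in metrics:
    # Same pattern as in the hunks below: shift the start point forward so
    # compilation time drops out of eval_runtime and the *_per_second metrics.
    start_time += metrics["eval_jit_compilation_time"]

num_samples, batch_size = 64, 8
metrics.update(speed_metrics("eval", start_time, num_samples, math.ceil(num_samples / batch_size)))
print(metrics)  # eval_runtime ~ 2.0 rather than ~ 5.0
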
@@ -51,10 +51,13 @@ class QuestionAnsweringTrainer(Trainer):
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
finally:
self.compute_metrics = compute_metrics
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@@ -74,7 +77,7 @@ class QuestionAnsweringTrainer(Trainer):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
metrics.update(output.metrics)
else:
metrics = {}
metrics = output.metrics
if self.args.should_log:
# Only the main node log the results by default
@@ -103,10 +106,13 @@ class QuestionAnsweringTrainer(Trainer):
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
finally:
self.compute_metrics = compute_metrics
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@@ -71,10 +71,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
finally:
self.compute_metrics = compute_metrics
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@@ -94,9 +97,9 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
output.metrics.update(metrics)
metrics.update(output.metrics)
else:
metrics = {}
metrics = output.metrics
if self.args.should_log:
# Only the main node log the results by default
@@ -106,7 +109,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
return metrics
def predict(
@@ -119,6 +122,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
start_time = time.time()
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
output = eval_loop(
@@ -128,10 +132,22 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
finally:
self.compute_metrics = compute_metrics
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
if self.post_process_function is None or self.compute_metrics is None:
return output
@@ -142,5 +158,5 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
metrics.update(output.metrics)
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
@@ -766,6 +766,7 @@ def parse_log_history(log_history):
_ = metrics.pop("eval_runtime", None)
_ = metrics.pop("eval_samples_per_second", None)
_ = metrics.pop("eval_steps_per_second", None)
_ = metrics.pop("eval_jit_compilation_time", None)
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
for k, v in metrics.items():
if k == "eval_loss":
@@ -1345,7 +1345,9 @@ class Trainer:
model = nn.DataParallel(model)
if self.args.jit_mode_eval:
start_time = time.time()
model = self.torch_jit_model_eval(model, dataloader, training)
self.jit_compilation_time = round(time.time() - start_time, 4)
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
@@ -2819,6 +2821,8 @@ class Trainer:
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@@ -2886,6 +2890,8 @@ class Trainer:
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@@ -3102,6 +3108,8 @@ class Trainer:
if all_losses is not None:
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "jit_compilation_time"):
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
@@ -224,7 +224,11 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None)
# Remove speed metrics
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
speed_metrics = [
m
for m in metrics.keys()
if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
]
for sm in speed_metrics:
_ = metrics.pop(sm, None)
return loss if len(metrics) == 0 else sum(metrics.values())
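
A small usage sketch for the filter extended above (the metric values are hypothetical): with keys ending in _compilation_time now treated as speed metrics, the JIT compilation time no longer leaks into the default hyperparameter-search objective.

from transformers.trainer_utils import default_compute_objective

# Hypothetical evaluation metrics; only the keys matter here.
metrics = {
    "eval_loss": 0.42,
    "eval_runtime": 30.1,
    "eval_samples_per_second": 33.2,
    "eval_jit_compilation_time": 12.5,
    "epoch": 3.0,
}

# eval_loss and epoch are popped, and every *_runtime, *_per_second and
# *_compilation_time key is stripped, so the objective falls back to the
# evaluation loss instead of also summing in the 12.5 s of compilation.
print(default_compute_objective(metrics))  # -> 0.42
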
@@ -339,6 +339,7 @@ class NotebookProgressCallback(TrainerCallback):
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
for k, v in metrics.items():
if k == f"{metric_key_prefix}_loss":
values["Validation Loss"] = v