"docker/git@developer.sourcefind.cn:modelzoo/rtmdet_mmcv.git" did not exist on "dec8177b41b8d4db9711d9d97d1c473863e10877"
Unverified commit ae06bce8 authored by Wang, Yi, committed by GitHub

exclude jit time from the speed metric calculation of evaluation and prediction (#20553)


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 25e10da4
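
The change in a nutshell: when args.jit_mode_eval is set, the first evaluation pass pays a one-time torch.jit tracing cost, which previously inflated eval_runtime and deflated the *_per_second throughput figures. The patch records the tracing time, publishes it as {metric_key_prefix}_jit_compilation_time, and shifts start_time forward by that amount before speed_metrics is computed. A minimal sketch of the arithmetic, assuming speed_metrics behaves like transformers.trainer_utils.speed_metrics (all numbers below are made up):

    # Suppose evaluation ran from t=0.0 to t=15.0 wall-clock seconds, of which
    # the first 3.0 s were torch.jit tracing (recorded by this patch).
    start_time = 0.0
    end_time = 15.0
    metrics = {"eval_loss": 0.42, "eval_jit_compilation_time": 3.0}

    if "eval_jit_compilation_time" in metrics:
        # Moving start_time forward by the compilation time subtracts it from
        # the measured runtime, so throughput reflects inference work only.
        start_time += metrics["eval_jit_compilation_time"]

    num_samples, num_steps = 800, 100
    runtime = end_time - start_time  # 12.0 s instead of 15.0 s
    metrics["eval_runtime"] = round(runtime, 4)
    metrics["eval_samples_per_second"] = round(num_samples / runtime, 3)  # 66.667, not 53.333
    metrics["eval_steps_per_second"] = round(num_steps / runtime, 3)      # 8.333, not 6.667
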
@@ -51,10 +51,13 @@ class QuestionAnsweringTrainer(Trainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -74,7 +77,7 @@ class QuestionAnsweringTrainer(Trainer):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
             metrics.update(output.metrics)
         else:
-            metrics = {}
+            metrics = output.metrics
         if self.args.should_log:
             # Only the main node log the results by default
@@ -103,10 +106,13 @@ class QuestionAnsweringTrainer(Trainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
...
@@ -71,10 +71,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -94,9 +97,9 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 if not key.startswith(f"{metric_key_prefix}_"):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
-            output.metrics.update(metrics)
+            metrics.update(output.metrics)
         else:
-            metrics = {}
+            metrics = output.metrics
         if self.args.should_log:
             # Only the main node log the results by default
@@ -106,7 +109,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
-        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics

     def predict(
@@ -119,6 +122,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
+        start_time = time.time()
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
         try:
             output = eval_loop(
@@ -128,10 +132,22 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
         if self.post_process_function is None or self.compute_metrics is None:
             return output
@@ -142,5 +158,5 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         for key in list(metrics.keys()):
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+        metrics.update(output.metrics)
         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
@@ -766,6 +766,7 @@ def parse_log_history(log_history):
             _ = metrics.pop("eval_runtime", None)
             _ = metrics.pop("eval_samples_per_second", None)
             _ = metrics.pop("eval_steps_per_second", None)
+            _ = metrics.pop("eval_jit_compilation_time", None)
             values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
             for k, v in metrics.items():
                 if k == "eval_loss":
...
@@ -1345,7 +1345,9 @@ class Trainer:
             model = nn.DataParallel(model)

         if self.args.jit_mode_eval:
+            start_time = time.time()
             model = self.torch_jit_model_eval(model, dataloader, training)
+            self.jit_compilation_time = round(time.time() - start_time, 4)

         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
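
The hunk above brackets the torch_jit_model_eval call with wall-clock timestamps. A hedged sketch of that measurement in isolation (timed_jit_eval is an invented helper, not Trainer API; torch.jit.trace is the real tracing entry point):

    import time
    import torch

    def timed_jit_eval(model, example_inputs):
        # Trace the model for inference and report how long tracing took,
        # mirroring what the patch stores in self.jit_compilation_time.
        start = time.time()
        with torch.no_grad():
            traced = torch.jit.trace(model, example_inputs)
        return traced, round(time.time() - start, 4)

    model = torch.nn.Linear(4, 2)
    traced, jit_compilation_time = timed_jit_eval(model, torch.randn(1, 4))
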
@@ -2819,6 +2821,8 @@ class Trainer:
         )

         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -2886,6 +2890,8 @@ class Trainer:
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )

         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -3102,6 +3108,8 @@ class Trainer:
         if all_losses is not None:
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+        if hasattr(self, "jit_compilation_time"):
+            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

         # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
...
@@ -224,7 +224,11 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
     loss = metrics.pop("eval_loss", None)
     _ = metrics.pop("epoch", None)
     # Remove speed metrics
-    speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
+    speed_metrics = [
+        m
+        for m in metrics.keys()
+        if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
+    ]
     for sm in speed_metrics:
         _ = metrics.pop(sm, None)
     return loss if len(metrics) == 0 else sum(metrics.values())
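
A quick illustration of the widened filter: with the new _compilation_time suffix, the JIT timing metric no longer leaks into the hyperparameter-search objective (values below are made up):

    metrics = {
        "eval_loss": 0.42,
        "eval_runtime": 12.3,
        "eval_samples_per_second": 81.0,
        "eval_jit_compilation_time": 3.1,
    }
    loss = metrics.pop("eval_loss", None)
    speed = [m for m in metrics if m.endswith(("_runtime", "_per_second", "_compilation_time"))]
    for sm in speed:
        metrics.pop(sm, None)
    # metrics is now empty, so the objective falls back to the loss: 0.42
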
...
@@ -339,6 +339,7 @@ class NotebookProgressCallback(TrainerCallback):
             _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
             _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
             _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
+            _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
             for k, v in metrics.items():
                 if k == f"{metric_key_prefix}_loss":
                     values["Validation Loss"] = v
...