"vscode:/vscode.git/clone" did not exist on "ea55bd86b9a452c87c5383afc707ab7d710a3043"
Unverified Commit d19b5a90 authored by xiangdong's avatar xiangdong Committed by GitHub
Browse files

Exclude torch.compile time from metrics computation (#31443)

* exclude compile time from metrics computation

* fix the quality issue
parent 2aa2a144
......@@ -3670,6 +3670,8 @@ class Trainer:
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
......@@ -3739,6 +3741,8 @@ class Trainer:
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
......@@ -3777,11 +3781,13 @@ class Trainer:
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
start_time = time.time()
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
self.model_preparation_time = round(time.time() - start_time, 4)
if self.is_fsdp_enabled:
self.model = model
......@@ -3954,6 +3960,8 @@ class Trainer:
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "jit_compilation_time"):
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
if hasattr(self, "model_preparation_time"):
metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment