"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "83303bc73efcd44d037833e57234b772c9810fda"
Unverified Commit afe479ad authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

[Trainer] Report both steps and num samples per second (#11818)



* [Trainer] Report both steps and num samples per second

* Fix batch number

* Update src/transformers/trainer_utils.py
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* Address review comments
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>
parent eaab9397
...@@ -518,6 +518,7 @@ def parse_log_history(log_history): ...@@ -518,6 +518,7 @@ def parse_log_history(log_history):
step = metrics.pop("step", None) step = metrics.pop("step", None)
_ = metrics.pop("eval_runtime", None) _ = metrics.pop("eval_runtime", None)
_ = metrics.pop("eval_samples_per_second", None) _ = metrics.pop("eval_samples_per_second", None)
_ = metrics.pop("eval_steps_per_second", None)
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step} values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
for k, v in metrics.items(): for k, v in metrics.items():
if k == "eval_loss": if k == "eval_loss":
...@@ -537,7 +538,7 @@ def parse_log_history(log_history): ...@@ -537,7 +538,7 @@ def parse_log_history(log_history):
for key, value in log_history[idx].items(): for key, value in log_history[idx].items():
if key.startswith("eval_"): if key.startswith("eval_"):
key = key[5:] key = key[5:]
if key not in ["runtime", "samples_per_second", "epoch", "step"]: if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")]) camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
eval_results[camel_cased_key] = value eval_results[camel_cased_key] = value
return train_log, lines, eval_results return train_log, lines, eval_results
......
...@@ -1077,6 +1077,7 @@ class Trainer: ...@@ -1077,6 +1077,7 @@ class Trainer:
# number of training epochs: num_train_epochs # number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch # number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps # total number of training steps to execute: max_steps
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
if train_dataset_is_sized: if train_dataset_is_sized:
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
...@@ -1085,14 +1086,19 @@ class Trainer: ...@@ -1085,14 +1086,19 @@ class Trainer:
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0 args.max_steps % num_update_steps_per_epoch > 0
) )
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
else: else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs) num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = len(self.train_dataset) * args.num_train_epochs
else: else:
# see __init__. max_steps is set when the dataset has no __len__ # see __init__. max_steps is set when the dataset has no __len__
max_steps = args.max_steps max_steps = args.max_steps
num_train_epochs = int(args.num_train_epochs) num_train_epochs = int(args.num_train_epochs)
num_update_steps_per_epoch = max_steps num_update_steps_per_epoch = max_steps
num_train_samples = args.max_steps * total_train_batch_size
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa debug_overflow = DebugUnderflowOverflow(self.model) # noqa
...@@ -1130,14 +1136,6 @@ class Trainer: ...@@ -1130,14 +1136,6 @@ class Trainer:
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
# Train! # Train!
if is_torch_tpu_available():
world_size = xm.xrt_world_size()
elif args.local_rank != -1:
world_size = dist.get_world_size()
else:
world_size = 1
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
num_examples = ( num_examples = (
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
) )
...@@ -1359,7 +1357,7 @@ class Trainer: ...@@ -1359,7 +1357,7 @@ class Trainer:
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
) )
metrics = speed_metrics("train", start_time, self.state.max_steps) metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
self.store_flos() self.store_flos()
metrics["total_flos"] = self.state.total_flos metrics["total_flos"] = self.state.total_flos
self.log(metrics) self.log(metrics)
...@@ -2009,7 +2007,15 @@ class Trainer: ...@@ -2009,7 +2007,15 @@ class Trainer:
metric_key_prefix=metric_key_prefix, metric_key_prefix=metric_key_prefix,
) )
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples)) total_batch_size = self.args.eval_batch_size * self.args.world_size
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
self.log(output.metrics) self.log(output.metrics)
...@@ -2066,7 +2072,15 @@ class Trainer: ...@@ -2066,7 +2072,15 @@ class Trainer:
output = eval_loop( output = eval_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
) )
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples)) total_batch_size = self.args.eval_batch_size * self.args.world_size
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
self._memory_tracker.stop_and_update_metrics(output.metrics) self._memory_tracker.stop_and_update_metrics(output.metrics)
......
...@@ -158,7 +158,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float: ...@@ -158,7 +158,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
loss = metrics.pop("eval_loss", None) loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None) _ = metrics.pop("epoch", None)
# Remove speed metrics # Remove speed metrics
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_samples_per_second")] speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
for sm in speed_metrics: for sm in speed_metrics:
_ = metrics.pop(sm, None) _ = metrics.pop(sm, None)
return loss if len(metrics) == 0 else sum(metrics.values()) return loss if len(metrics) == 0 else sum(metrics.values())
...@@ -232,7 +232,7 @@ def total_processes_number(local_rank): ...@@ -232,7 +232,7 @@ def total_processes_number(local_rank):
return 1 return 1
def speed_metrics(split, start_time, num_samples=None): def speed_metrics(split, start_time, num_samples=None, num_steps=None):
""" """
Measure and return speed performance metrics. Measure and return speed performance metrics.
...@@ -248,8 +248,11 @@ def speed_metrics(split, start_time, num_samples=None): ...@@ -248,8 +248,11 @@ def speed_metrics(split, start_time, num_samples=None):
runtime = time.time() - start_time runtime = time.time() - start_time
result = {f"{split}_runtime": round(runtime, 4)} result = {f"{split}_runtime": round(runtime, 4)}
if num_samples is not None: if num_samples is not None:
samples_per_second = 1 / (runtime / num_samples) samples_per_second = num_samples / runtime
result[f"{split}_samples_per_second"] = round(samples_per_second, 3) result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
if num_steps is not None:
steps_per_second = num_steps / runtime
result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
return result return result
......
...@@ -327,6 +327,7 @@ class NotebookProgressCallback(TrainerCallback): ...@@ -327,6 +327,7 @@ class NotebookProgressCallback(TrainerCallback):
_ = metrics.pop("epoch", None) _ = metrics.pop("epoch", None)
_ = metrics.pop(f"{metric_key_prefix}_runtime", None) _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
for k, v in metrics.items(): for k, v in metrics.items():
if k == f"{metric_key_prefix}_loss": if k == f"{metric_key_prefix}_loss":
values["Validation Loss"] = v values["Validation Loss"] = v
......
...@@ -316,6 +316,8 @@ class TrainerIntegrationCommon: ...@@ -316,6 +316,8 @@ class TrainerIntegrationCommon:
_ = log1.pop("train_runtime", None) _ = log1.pop("train_runtime", None)
_ = log.pop("train_samples_per_second", None) _ = log.pop("train_samples_per_second", None)
_ = log1.pop("train_samples_per_second", None) _ = log1.pop("train_samples_per_second", None)
_ = log.pop("train_steps_per_second", None)
_ = log1.pop("train_steps_per_second", None)
self.assertEqual(log, log1) self.assertEqual(log, log1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment