Unverified commit 3744126c authored by CokeDong, committed by GitHub

Add `tgs` speed metrics (#25858)



* Add tgs metrics

* bugfix and black formatting

* workaround for token counting

* formatting and bugfix

* Fix

* Add opt-in for tgs metrics

* make style and fix error

* Fix doc

* fix docbuild

* hf-doc-build

* fix

* test

* Update src/transformers/training_args.py

renaming
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Update src/transformers/training_args.py

renaming
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Fix some symbols

* test

* Update src/transformers/trainer_utils.py

match naming patterns
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/trainer.py

nice
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Address review comments

* Fix

* Fix black

---------
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0188739a
src/transformers/trainer.py
@@ -1159,6 +1159,22 @@ class Trainer:
         except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
             return len(dataloader) * self.args.per_device_train_batch_size

+    def num_tokens(self, train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
+        """
+        Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader.
+        """
+        train_tokens = 0
+        try:
+            for step, batch in enumerate(train_dl):
+                tokens = batch["input_ids"].numel()
+                if max_steps is not None:
+                    return tokens * max_steps
+                train_tokens += tokens
+            return train_tokens
+        except KeyError:
+            logger.warning("Cannot get num_tokens from dataloader")
+            return train_tokens
+
     def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         """HP search setup code"""
         self._trial = trial
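Note: the helper above counts tokens by enumerating the dataloader, and when `max_steps` is passed it extrapolates from the first batch instead of iterating everything. A minimal standalone sketch of that logic, with toy data instead of the Trainer (all names and sizes here are made up for illustration):

```python
# Minimal sketch of the counting logic with toy data (assumed shapes/values).
import torch
from torch.utils.data import DataLoader

# Hypothetical dataset: 10 examples, each 512 tokens long.
data = [{"input_ids": torch.zeros(512, dtype=torch.long)} for _ in range(10)]
loader = DataLoader(data, batch_size=2)  # 5 batches, each input_ids of shape (2, 512)

# Exact count: enumerate every batch.
total = 0
for batch in loader:
    total += batch["input_ids"].numel()  # 2 * 512 = 1024 tokens per batch
print(total)  # 5120

# With max_steps, the helper instead extrapolates from the first batch,
# which assumes roughly uniform batch sizes.
first = next(iter(loader))["input_ids"].numel()
print(first * 100)  # 1024 * 100 = 102400
```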
@@ -1576,6 +1592,7 @@ class Trainer:
         total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

         len_dataloader = None
+        num_train_tokens = None
         if has_length(train_dataloader):
             len_dataloader = len(train_dataloader)
             num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
@@ -1589,10 +1606,16 @@ class Trainer:
                 # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                 # the best we can do.
                 num_train_samples = args.max_steps * total_train_batch_size
+                if args.include_tokens_per_second:
+                    num_train_tokens = (
+                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+                    )
             else:
                 max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                 num_train_epochs = math.ceil(args.num_train_epochs)
                 num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+                if args.include_tokens_per_second:
+                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
         elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
             max_steps = args.max_steps
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
......@@ -1600,6 +1623,8 @@ class Trainer:
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
@@ -1976,7 +2001,13 @@ class Trainer:
         self._total_loss_scalar += tr_loss.item()
         train_loss = self._total_loss_scalar / self.state.global_step

-        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+        metrics = speed_metrics(
+            "train",
+            start_time,
+            num_samples=num_train_samples,
+            num_steps=self.state.max_steps,
+            num_tokens=num_train_tokens,
+        )
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
         metrics["train_loss"] = train_loss
src/transformers/trainer_utils.py
@@ -335,7 +335,7 @@ def total_processes_number(local_rank):
     return 1


-def speed_metrics(split, start_time, num_samples=None, num_steps=None):
+def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
     """
     Measure and return speed performance metrics.
@@ -346,6 +346,7 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None):
     - split: name to prefix metric (like train, eval, test...)
     - start_time: operation start time
     - num_samples: number of samples processed
+    - num_tokens: number of tokens processed
     """
     runtime = time.time() - start_time
     result = {f"{split}_runtime": round(runtime, 4)}
@@ -357,6 +358,9 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None):
     if num_steps is not None:
         steps_per_second = num_steps / runtime
         result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
+    if num_tokens is not None:
+        tokens_per_second = num_tokens / runtime
+        result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
     return result
src/transformers/training_args.py
@@ -633,6 +633,12 @@ class TrainingArguments:
             Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.

             This flag is experimental and subject to change in future releases.
+        include_tokens_per_second (`bool`, *optional*):
+            Whether or not to compute the number of tokens per second per device for training speed metrics.
+
+            This will iterate over the entire training dataloader once beforehand,
+            and will slow down the entire process.
+
     """

     framework = "pt"
@@ -1232,6 +1238,11 @@ class TrainingArguments:
         },
     )
+    include_tokens_per_second: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
+    )

     def __post_init__(self):
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
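Opting in from user code is a one-liner on `TrainingArguments`; a minimal sketch (the `output_dir` value is a placeholder):

```python
# Enabling the new opt-in metric; output_dir is a placeholder path.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./results",
    include_tokens_per_second=True,  # iterates the train dataloader once up front to count tokens
)
# After Trainer.train(), the logged train metrics include "train_tokens_per_second".
```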