Unverified Commit ae6b6963 authored by pcuenca's avatar pcuenca Committed by GitHub
Browse files

Allow use of pre-computed lengths when grouping by length. (#10953)

A new argument `length_column_name` has been added to
`TrainingArguments`, with default value `"length"`. If this column
exists and `group_by_length` is `True`, the train sampler will use
it for grouping rather than computing it before training starts.

This is an optimization that allows the user to prepare data for fast
processing, preventing sequential access to the dataset as described in
issue #10909.
parent 4002f95e
......@@ -496,10 +496,18 @@ class Trainer:
# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
lengths = (
self.train_dataset[self.args.length_column_name]
if self.args.length_column_name in self.train_dataset.column_names
else None
)
else:
lengths = None
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if self.args.world_size <= 1:
return LengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
)
else:
return DistributedLengthGroupedSampler(
......@@ -507,6 +515,7 @@ class Trainer:
self.args.train_batch_size,
num_replicas=self.args.world_size,
rank=self.args.process_index,
lengths=lengths,
model_input_name=model_input_name,
)
......
......@@ -277,6 +277,10 @@ class TrainingArguments:
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to group together samples of roughly the same legnth in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
length_column_name (:obj:`str`, `optional`, defaults to :obj:`"length"`):
Column name for precomputed lengths. If the column exists, grouping by length will use these values rather
than computing them on train startup. Ignored unless :obj:`group_by_length` is :obj:`True` and the dataset
is an instance of :obj:`Dataset`.
report_to (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`"all"`):
The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`. Use :obj:`"all"` to report to
......@@ -494,6 +498,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
)
length_column_name: Optional[str] = field(
default="length",
metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
)
report_to: Optional[List[str]] = field(
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment