Unverified Commit 98d88b23 authored by Stas Bekman, committed by GitHub

[`run_(clm|mlm).py` examples] add streaming dataset support (#21343)

* [run_clm example] add streaming dataset support

* unrefactor kwargs

* fix

* fix

* require datasets>=2.0.0

* port to mlm
parent 95be242a
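For context (not part of the commit): with this change, a typical invocation of the updated script could look like the sketch below. The model name, dataset, and other values are illustrative, and since a streamed dataset has no known length, the `Trainer` generally needs an explicit `--max_steps`.

```bash
# Hypothetical usage sketch: stream the dataset instead of downloading and caching it first.
python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-103-raw-v1 \
    --streaming \
    --max_steps 1000 \
    --per_device_train_batch_size 8 \
    --do_train \
    --output_dir /tmp/test-clm-streaming
```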
@@ -174,6 +174,9 @@ concatenates all texts and then splits them in blocks of the same length).
 **Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
 sure all your batches have the same length.
 
+## Streaming
+
+To use the streaming dataset mode which can be very useful for large datasets, add `--streaming` to the command line. This is currently supported by `run_mlm.py` and `run_clm.py`.
 
 ## Creating a model on the fly
......
@@ -173,7 +173,7 @@ class DataTrainingArguments:
             )
         },
     )
+    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
     block_size: Optional[int] = field(
         default=None,
         metadata={
@@ -202,6 +202,9 @@ class DataTrainingArguments:
     )
 
     def __post_init__(self):
+        if self.streaming:
+            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
+
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
             raise ValueError("Need either a dataset name or a training/validation file.")
         else:
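The two hunks above add the `streaming` field to `DataTrainingArguments` and guard it with a version check. As a rough standalone sketch (not part of this commit; the toy dataclass below is hypothetical), `HfArgumentParser` is what turns such a boolean field into the `--streaming` command-line flag documented in the README:

```python
# Minimal sketch with a toy dataclass; it mirrors the field and guard added by this commit.
from dataclasses import dataclass, field

from transformers import HfArgumentParser
from transformers.utils.versions import require_version


@dataclass
class ToyDataArguments:
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})

    def __post_init__(self):
        if self.streaming:
            # Same guard as in the diff: streaming needs datasets>=2.0.0.
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")


parser = HfArgumentParser(ToyDataArguments)
(args,) = parser.parse_args_into_dataclasses(args=["--streaming"])
print(args.streaming)  # True
```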
@@ -285,6 +288,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             use_auth_token=True if model_args.use_auth_token else None,
+            streaming=data_args.streaming,
         )
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
@@ -293,6 +297,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None,
+                streaming=data_args.streaming,
             )
             raw_datasets["train"] = load_dataset(
                 data_args.dataset_name,
@@ -300,6 +305,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None,
+                streaming=data_args.streaming,
             )
     else:
         data_files = {}
@@ -413,9 +419,15 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = raw_datasets["train"].column_names
+        if data_args.streaming:
+            column_names = raw_datasets["train"].features.keys()
+        else:
+            column_names = raw_datasets["train"].column_names
     else:
-        column_names = raw_datasets["validation"].column_names
+        if data_args.streaming:
+            column_names = raw_datasets["validation"].features.keys()
+        else:
+            column_names = raw_datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
@@ -433,14 +445,21 @@ def main():
         return output
 
     with training_args.main_process_first(desc="dataset map tokenization"):
-        tokenized_datasets = raw_datasets.map(
-            tokenize_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on dataset",
-        )
+        if not data_args.streaming:
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on dataset",
+            )
+        else:
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                remove_columns=column_names,
+            )
 
     if data_args.block_size is None:
         block_size = tokenizer.model_max_length
@@ -483,13 +502,19 @@ def main():
     # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
 
     with training_args.main_process_first(desc="grouping texts together"):
-        lm_datasets = tokenized_datasets.map(
-            group_texts,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc=f"Grouping texts in chunks of {block_size}",
-        )
+        if not data_args.streaming:
+            lm_datasets = tokenized_datasets.map(
+                group_texts,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc=f"Grouping texts in chunks of {block_size}",
+            )
+        else:
+            lm_datasets = tokenized_datasets.map(
+                group_texts,
+                batched=True,
+            )
 
     if training_args.do_train:
         if "train" not in tokenized_datasets:
......
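The branching on `data_args.streaming` in the hunks above, mirrored in the hunks below, reflects how `datasets` behaves in streaming mode: `load_dataset(..., streaming=True)` returns iterable datasets whose lazy `map()` accepts only a subset of the arguments (no `num_proc`, `load_from_cache_file`, or `desc`), and the scripts read column names from `.features` instead of `.column_names`. A rough illustration of both paths, assuming `datasets>=2.0.0`; the tokenizer and dataset choices are arbitrary:

```python
# Hypothetical sketch (not part of the commit) of the streaming vs. non-streaming code paths.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
streaming = True  # stands in for data_args.streaming

raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1", streaming=streaming)

# Streaming returns an IterableDatasetDict; column names come from the features metadata.
if streaming:
    column_names = list(raw_datasets["train"].features.keys())
else:
    column_names = raw_datasets["train"].column_names


def tokenize_function(examples):
    return tokenizer(examples["text"])


if not streaming:
    # Map-style datasets: multiprocessing, Arrow caching, and progress descriptions are available.
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
else:
    # Iterable datasets: map() is applied lazily, so no num_proc / cache / desc arguments.
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
    )

print(next(iter(tokenized_datasets["train"])).keys())  # tokenized examples are produced on the fly
```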
@@ -197,8 +197,12 @@ class DataTrainingArguments:
             )
         },
     )
+    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
 
     def __post_init__(self):
+        if self.streaming:
+            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
+
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
             raise ValueError("Need either a dataset name or a training/validation file.")
         else:
@@ -285,6 +289,7 @@ def main():
             data_args.dataset_config_name,
             cache_dir=model_args.cache_dir,
             use_auth_token=True if model_args.use_auth_token else None,
+            streaming=data_args.streaming,
         )
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
@@ -293,6 +298,7 @@ def main():
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None,
+                streaming=data_args.streaming,
             )
             raw_datasets["train"] = load_dataset(
                 data_args.dataset_name,
@@ -300,6 +306,7 @@ def main():
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None,
+                streaming=data_args.streaming,
             )
     else:
         data_files = {}
@@ -398,9 +405,15 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = raw_datasets["train"].column_names
+        if data_args.streaming:
+            column_names = raw_datasets["train"].features.keys()
+        else:
+            column_names = raw_datasets["train"].column_names
     else:
-        column_names = raw_datasets["validation"].column_names
+        if data_args.streaming:
+            column_names = raw_datasets["validation"].features.keys()
+        else:
+            column_names = raw_datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     if data_args.max_seq_length is None:
@@ -439,14 +452,21 @@ def main():
             )
 
         with training_args.main_process_first(desc="dataset map tokenization"):
-            tokenized_datasets = raw_datasets.map(
-                tokenize_function,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=[text_column_name],
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc="Running tokenizer on dataset line_by_line",
-            )
+            if not data_args.streaming:
+                tokenized_datasets = raw_datasets.map(
+                    tokenize_function,
+                    batched=True,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_columns=[text_column_name],
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    desc="Running tokenizer on dataset line_by_line",
+                )
+            else:
+                tokenized_datasets = raw_datasets.map(
+                    tokenize_function,
+                    batched=True,
+                    remove_columns=[text_column_name],
+                )
     else:
         # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
         # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
@@ -455,14 +475,21 @@ def main():
             return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
 
         with training_args.main_process_first(desc="dataset map tokenization"):
-            tokenized_datasets = raw_datasets.map(
-                tokenize_function,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=column_names,
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc="Running tokenizer on every text in dataset",
-            )
+            if not data_args.streaming:
+                tokenized_datasets = raw_datasets.map(
+                    tokenize_function,
+                    batched=True,
+                    num_proc=data_args.preprocessing_num_workers,
+                    remove_columns=column_names,
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    desc="Running tokenizer on every text in dataset",
+                )
+            else:
+                tokenized_datasets = raw_datasets.map(
+                    tokenize_function,
+                    batched=True,
+                    remove_columns=column_names,
+                )
 
         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
         # max_seq_length.
@@ -489,13 +516,19 @@ def main():
         # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
 
         with training_args.main_process_first(desc="grouping texts together"):
-            tokenized_datasets = tokenized_datasets.map(
-                group_texts,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc=f"Grouping texts in chunks of {max_seq_length}",
-            )
+            if not data_args.streaming:
+                tokenized_datasets = tokenized_datasets.map(
+                    group_texts,
+                    batched=True,
+                    num_proc=data_args.preprocessing_num_workers,
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    desc=f"Grouping texts in chunks of {max_seq_length}",
+                )
+            else:
+                tokenized_datasets = tokenized_datasets.map(
+                    group_texts,
+                    batched=True,
+                )
 
     if training_args.do_train:
         if "train" not in tokenized_datasets:
......