Unverified Commit f176e707 authored by Kelvin, committed by GitHub

The input training data files (multiple files in glob format). (#7717)

Splitting large files into smaller ones can often prevent the tokenizer from running out of memory in environments like Colab that have no swap memory.
parent 34fcfb44
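
For context before the diff: the change is an instance of the standard PyTorch pattern of expanding a filename glob into one dataset per shard and merging the pieces with torch.utils.data.ConcatDataset, so each file is read and processed on its own rather than as one huge file. A minimal, self-contained sketch of that pattern, in which the LineDataset class and the "data/shard-*.txt" pattern are hypothetical stand-ins rather than part of this commit:

from glob import glob

from torch.utils.data import ConcatDataset, Dataset


class LineDataset(Dataset):
    # Hypothetical stand-in for LineByLineTextDataset: one example per
    # non-empty line of the file.
    def __init__(self, file_path):
        with open(file_path, encoding="utf-8") as f:
            self.lines = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, i):
        return self.lines[i]


# Expand the glob into one dataset per shard, then present them to the
# trainer as a single concatenated dataset. Each shard is opened and
# processed independently of the others.
train_dataset = ConcatDataset([LineDataset(path) for path in sorted(glob("data/shard-*.txt"))])
print(len(train_dataset))
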
@@ -24,8 +24,11 @@ import logging
 import math
 import os
 from dataclasses import dataclass, field
+from glob import glob
 from typing import Optional
 
+from torch.utils.data import ConcatDataset
+
 from transformers import (
     CONFIG_MAPPING,
     MODEL_WITH_LM_HEAD_MAPPING,
@@ -87,6 +90,12 @@ class DataTrainingArguments:
     train_data_file: Optional[str] = field(
         default=None, metadata={"help": "The input training data file (a text file)."}
     )
+    train_data_files: Optional[str] = field(
+        default=None, metadata={
+            "help": "The input training data files (multiple files in glob format). "
+            "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
+        }
+    )
     eval_data_file: Optional[str] = field(
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
@@ -131,7 +140,7 @@ def get_dataset(
     evaluate: bool = False,
     cache_dir: Optional[str] = None,
 ):
-    file_path = args.eval_data_file if evaluate else args.train_data_file
+    def _dataset(file_path):
     if args.line_by_line:
         return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
     else:
@@ -143,6 +152,13 @@ def get_dataset(
             cache_dir=cache_dir,
         )
 
+    if evaluate:
+        return _dataset(args.eval_data_file)
+    elif args.train_data_files:
+        return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
+    else:
+        return _dataset(args.train_data_file)
+
 
 def main():
     # See all possible arguments in src/transformers/training_args.py
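
With the new field in place, the flag surfaces on the command line through the HfArgumentParser flow these example scripts use. A hedged sketch of that wiring, with the dataclass trimmed to the two relevant fields and a hypothetical "data/shard-*.txt" pattern:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # Trimmed-down illustration; the real class defines several more fields.
    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    train_data_files: Optional[str] = field(
        default=None, metadata={"help": "The input training data files (multiple files in glob format)."}
    )


parser = HfArgumentParser(DataTrainingArguments)
# Quote the pattern in the shell so the glob reaches the script unexpanded;
# the script itself expands it with glob().
(data_args,) = parser.parse_args_into_dataclasses(["--train_data_files", "data/shard-*.txt"])
print(data_args.train_data_files)  # data/shard-*.txt

Note the ordering in get_dataset above: when both --train_data_file and --train_data_files are given, the elif branch means the glob takes precedence for training, while evaluation always reads --eval_data_file.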