Unverified commit f176e707, authored by Kelvin and committed by GitHub

The input training data files (multiple files in glob format). (#7717)

Splitting large files into smaller ones can often prevent the tokenizer from running out of memory in environments like Colab that do not have swap memory.
parent 34fcfb44
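The commit message refers to sharding the raw corpus before training; the script itself does not split anything. A minimal pre-splitting sketch, assuming a hypothetical large_corpus.txt and an arbitrary shard naming scheme (neither is part of this commit):

# Illustrative only: shard a large text corpus into smaller files before training.
# "large_corpus.txt" and the ".shardNNN.txt" naming are hypothetical examples.
from itertools import islice


def split_corpus(path, lines_per_shard=100_000):
    # Stream the source file and write consecutive chunks to numbered shard files.
    with open(path, encoding="utf-8") as src:
        shard_idx = 0
        while True:
            chunk = list(islice(src, lines_per_shard))
            if not chunk:
                break
            with open(f"{path}.shard{shard_idx:03d}.txt", "w", encoding="utf-8") as dst:
                dst.writelines(chunk)
            shard_idx += 1


split_corpus("large_corpus.txt")

The resulting shards can then be passed to the script as a single glob pattern via the new argument added in the diff below.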
@@ -24,8 +24,11 @@ import logging
 import math
 import os
 from dataclasses import dataclass, field
+from glob import glob
 from typing import Optional
 
+from torch.utils.data import ConcatDataset
+
 from transformers import (
     CONFIG_MAPPING,
     MODEL_WITH_LM_HEAD_MAPPING,
@@ -87,6 +90,12 @@ class DataTrainingArguments:
     train_data_file: Optional[str] = field(
         default=None, metadata={"help": "The input training data file (a text file)."}
     )
+    train_data_files: Optional[str] = field(
+        default=None, metadata={
+            "help": "The input training data files (multiple files in glob format). "
+            "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
+        }
+    )
     eval_data_file: Optional[str] = field(
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
@@ -131,17 +140,24 @@ def get_dataset(
     evaluate: bool = False,
     cache_dir: Optional[str] = None,
 ):
-    file_path = args.eval_data_file if evaluate else args.train_data_file
-    if args.line_by_line:
-        return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
-    else:
-        return TextDataset(
-            tokenizer=tokenizer,
-            file_path=file_path,
-            block_size=args.block_size,
-            overwrite_cache=args.overwrite_cache,
-            cache_dir=cache_dir,
-        )
+    def _dataset(file_path):
+        if args.line_by_line:
+            return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
+        else:
+            return TextDataset(
+                tokenizer=tokenizer,
+                file_path=file_path,
+                block_size=args.block_size,
+                overwrite_cache=args.overwrite_cache,
+                cache_dir=cache_dir,
+            )
+
+    if evaluate:
+        return _dataset(args.eval_data_file)
+    elif args.train_data_files:
+        return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
+    else:
+        return _dataset(args.train_data_file)
 
 
 def main():
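For reference, the new code path can be exercised on its own: glob expands the pattern into concrete file paths, each path is wrapped in its own dataset, and torch.utils.data.ConcatDataset presents them to the Trainer as one indexable dataset. A minimal sketch, assuming hypothetical shards data/train-*.txt, a gpt2 tokenizer, and the line-by-line variant:

# Sketch of the loading pattern introduced above; paths and tokenizer are
# illustrative assumptions, not values used by the script.
from glob import glob

from torch.utils.data import ConcatDataset
from transformers import AutoTokenizer, LineByLineTextDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Each shard matched by the glob becomes its own dataset; ConcatDataset
# exposes them to the Trainer as a single dataset.
shards = [
    LineByLineTextDataset(tokenizer=tokenizer, file_path=path, block_size=128)
    for path in glob("data/train-*.txt")
]
train_dataset = ConcatDataset(shards)
print(len(train_dataset))  # total number of examples across all shards

On the command line this corresponds to something like --train_data_files 'data/train-*.txt' (the script's HfArgumentParser derives the flag from the dataclass field; the pattern should be quoted so the shell does not expand it before the script sees it).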