Unverified Commit 4046e66e authored by Stefan Schweter, committed by GitHub

examples: only use keep_linebreaks when reading TXT files (#13320)

* examples: only use keep_linebreaks when reading TXT files for all CLM examples
parent b6f332ec
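Background: `keep_linebreaks` is an option of the `text` loading script in the datasets library only. Passing it unconditionally breaks loading of CSV/JSON files, because those builders reject the unknown config key. A minimal sketch of the failure this patch avoids (file names are illustrative, and the exact exception type is an assumption based on how datasets validates builder configs):

```python
from datasets import load_dataset

# Fine: keep_linebreaks is a config option of the "text" builder.
ds = load_dataset("text", data_files={"train": "train.txt"}, keep_linebreaks=True)

# Not fine: the "csv" builder has no keep_linebreaks option, so the
# unexpected config kwarg is rejected.
try:
    load_dataset("csv", data_files={"train": "train.csv"}, keep_linebreaks=True)
except ValueError as err:
    print(err)  # e.g. "BuilderConfig CsvConfig(...) doesn't have a 'keep_linebreaks' key."
```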
@@ -157,7 +157,7 @@ class DataTrainingArguments:
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
     keep_linebreaks: bool = field(
-        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
     )
 
     def __post_init__(self):
@@ -305,6 +305,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
@@ -312,22 +313,23 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
 
         if "validation" not in dataset.keys():
             dataset["validation"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )
             dataset["train"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
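The fix follows the same pattern in every script: an extra-kwargs dict is populated only inside the `txt` branch, so non-TXT builders never see the parameter, and every `load_dataset` call unpacks the same dict. Condensed into a standalone sketch (the helper name and defaults are illustrative, not part of the patch):

```python
from datasets import load_dataset

def load_clm_dataset(train_file, keep_linebreaks=True, cache_dir=None):
    # Condensed version of the post-patch loading logic.
    data_files = {"train": train_file}
    dataset_args = {}
    extension = train_file.split(".")[-1]
    if extension == "txt":
        extension = "text"
        # Only the "text" builder understands keep_linebreaks.
        dataset_args["keep_linebreaks"] = keep_linebreaks
    return load_dataset(extension, data_files=data_files, cache_dir=cache_dir, **dataset_args)
```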
@@ -173,7 +173,7 @@ class DataTrainingArguments:
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
     keep_linebreaks: bool = field(
-        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
     )
 
     def __post_init__(self):
@@ -269,6 +269,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
@@ -280,22 +281,23 @@ def main():
         )
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
         # If no validation data is there, validation_split_percentage will be used to divide the dataset.
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )
             raw_datasets["train"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
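When no validation file is supplied, the scripts carve a validation set out of the training data with the datasets slice syntax; because both calls unpack the same `dataset_args`, the two splits are loaded with identical line-break handling. A sketch, assuming a 5% validation share and a plain-text train file (file name illustrative):

```python
from datasets import load_dataset

data_files = {"train": "train.txt"}       # illustrative file name
dataset_args = {"keep_linebreaks": True}  # set because the extension is txt
validation_split_percentage = 5

validation = load_dataset(
    "text",
    data_files=data_files,
    split=f"train[:{validation_split_percentage}%]",  # first 5% of the train file
    **dataset_args,
)
train = load_dataset(
    "text",
    data_files=data_files,
    split=f"train[{validation_split_percentage}%:]",  # remaining 95%
    **dataset_args,
)
```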
@@ -174,7 +174,7 @@ def parse_args():
         "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
     )
     parser.add_argument(
-        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files."
+        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
     )
 
     args = parser.parse_args()
@@ -248,6 +248,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if args.train_file is not None:
             data_files["train"] = args.train_file
         if args.validation_file is not None:
@@ -255,20 +256,21 @@ def main():
         extension = args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, data_files=data_files)
+            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
         # If no validation data is there, validation_split_percentage will be used to divide the dataset.
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 extension,
-                keep_linebreaks=not args.no_keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{args.validation_split_percentage}%]",
+                **dataset_args,
             )
             raw_datasets["train"] = load_dataset(
                 extension,
-                keep_linebreaks=not args.no_keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{args.validation_split_percentage}%:]",
+                **dataset_args,
             )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
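The argparse-based script above exposes the option as a negative flag, so the stored value is inverted before it reaches the builder: `--no_keep_linebreaks` sets `args.no_keep_linebreaks` to `True`, and `not args.no_keep_linebreaks` therefore defaults to keeping line breaks. A minimal sketch of that flag handling:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
)

args = parser.parse_args([])            # flag absent: store_true default is False
print(not args.no_keep_linebreaks)      # True -> line breaks are kept

args = parser.parse_args(["--no_keep_linebreaks"])
print(not args.no_keep_linebreaks)      # False -> line breaks are dropped
```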
@@ -187,7 +187,7 @@ class DataTrainingArguments:
         },
     )
     keep_linebreaks: bool = field(
-        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
     )
 
     def __post_init__(self):
@@ -321,6 +321,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
@@ -328,7 +329,8 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, keep_linebreaks=data_args.keep_linebreaks, data_files=data_files)
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
     # endregion
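For reference, what the option actually changes: with `keep_linebreaks=True` each example produced by the `text` builder keeps its trailing newline, which matters for causal language modeling where line breaks are real tokens. A quick self-contained check (behavior as understood from the datasets text builder, whose default is `keep_linebreaks=False`; the sample file is created on the fly):

```python
from datasets import load_dataset

# Write a small illustrative text file.
with open("sample.txt", "w") as f:
    f.write("first line\nsecond line\n")

kept = load_dataset("text", data_files={"train": "sample.txt"}, keep_linebreaks=True)
stripped = load_dataset("text", data_files={"train": "sample.txt"})  # default strips them

print(repr(kept["train"][0]["text"]))      # 'first line\n'
print(repr(stripped["train"][0]["text"]))  # 'first line'
```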