Unverified Commit 39fa4009 authored by Klaus Hipp's avatar Klaus Hipp Committed by GitHub
Browse files

Fix input data file extension in examples (#28741)

parent 5649c0cb
...@@ -320,9 +320,10 @@ def main(): ...@@ -320,9 +320,10 @@ def main():
data_files = {} data_files = {}
if data_args.train_file is not None: if data_args.train_file is not None:
data_files["train"] = data_args.train_file data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None: if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1] extension = data_args.validation_file.split(".")[-1]
raw_datasets = load_dataset( raw_datasets = load_dataset(
extension, extension,
data_files=data_files, data_files=data_files,
......
...@@ -260,9 +260,10 @@ def main(): ...@@ -260,9 +260,10 @@ def main():
data_files = {} data_files = {}
if data_args.train_file is not None: if data_args.train_file is not None:
data_files["train"] = data_args.train_file data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None: if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1] extension = data_args.validation_file.split(".")[-1]
raw_datasets = load_dataset( raw_datasets = load_dataset(
extension, extension,
data_files=data_files, data_files=data_files,
......
...@@ -730,9 +730,10 @@ def main(): ...@@ -730,9 +730,10 @@ def main():
data_files = {} data_files = {}
if args.train_file is not None: if args.train_file is not None:
data_files["train"] = args.train_file data_files["train"] = args.train_file
extension = args.train_file.split(".")[-1]
if args.validation_file is not None: if args.validation_file is not None:
data_files["validation"] = args.validation_file data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1] extension = args.validation_file.split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files) raw_datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets. # https://huggingface.co/docs/datasets/loading_datasets.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment