Unverified commit b58f67f2, authored by aihao, committed by GitHub
Browse files

update (#7067)



* add data_dir parameter to load_dataset

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
parent 8ac6de96
...@@ -571,9 +571,6 @@ def parse_args(input_args=None): ...@@ -571,9 +571,6 @@ def parse_args(input_args=None):
if args.dataset_name is None and args.train_data_dir is None: if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
if args.dataset_name is not None and args.train_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
...@@ -615,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator): ...@@ -615,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
args.dataset_name, args.dataset_name,
args.dataset_config_name, args.dataset_config_name,
cache_dir=args.cache_dir, cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
) )
else: else:
if args.train_data_dir is not None: if args.train_data_dir is not None:
......
...@@ -598,9 +598,6 @@ def parse_args(input_args=None): ...@@ -598,9 +598,6 @@ def parse_args(input_args=None):
if args.dataset_name is None and args.train_data_dir is None: if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
if args.dataset_name is not None and args.train_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
...@@ -642,6 +639,7 @@ def get_train_dataset(args, accelerator): ...@@ -642,6 +639,7 @@ def get_train_dataset(args, accelerator):
args.dataset_name, args.dataset_name,
args.dataset_config_name, args.dataset_config_name,
cache_dir=args.cache_dir, cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
) )
else: else:
if args.train_data_dir is not None: if args.train_data_dir is not None:
......
...@@ -483,7 +483,6 @@ def parse_args(input_args=None): ...@@ -483,7 +483,6 @@ def parse_args(input_args=None):
# Sanity checks # Sanity checks
if args.dataset_name is None and args.train_data_dir is None: if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.") raise ValueError("Need either a dataset name or a training folder.")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
...@@ -824,9 +823,7 @@ def main(args): ...@@ -824,9 +823,7 @@ def main(args):
if args.dataset_name is not None: if args.dataset_name is not None:
# Downloading and loading a dataset from the hub. # Downloading and loading a dataset from the hub.
dataset = load_dataset( dataset = load_dataset(
args.dataset_name, args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
args.dataset_config_name,
cache_dir=args.cache_dir,
) )
else: else:
data_files = {} data_files = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment