"git@developer.sourcefind.cn:chenpangpang/Real-ESRGAN.git" did not exist on "5fb982294e72ca6643cd594118aa66f396d503ec"
Unverified commit a73281e3, authored by Stas Bekman, committed by GitHub

[examples] max samples can't be bigger than the len of dataset (#16501)

* [examples] max samples can't be bigger than the len of dataset

* do tf and flax
parent c4deb7b3
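
The fix is the same two-line pattern in every script: cap the requested sample count at the dataset length before calling .select(). A minimal sketch of the pattern, assuming a datasets.Dataset-style object that supports len() and .select(); the clamp_select helper is hypothetical and only illustrates what the diffs below inline:

from datasets import Dataset

def clamp_select(dataset, max_samples):
    # Hypothetical helper, not part of this commit: clamp before selecting.
    if max_samples is None:
        return dataset
    # Without min(), range(max_samples) can produce indices past the end of
    # the dataset, which .select() rejects.
    return dataset.select(range(min(len(dataset), max_samples)))

ds = Dataset.from_dict({"x": list(range(5))})
print(len(clamp_select(ds, 100)))  # 5, not an out-of-range error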
@@ -415,9 +415,11 @@ def main():
     train_dataset = train_dataset.select(train_indices)
     if data_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     if data_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
...
@@ -456,9 +456,11 @@ def main():
     train_dataset = train_dataset.select(train_indices)
     if data_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     if data_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
...
@@ -369,7 +369,8 @@ def main():
     train_dataset = raw_datasets["train"]
     non_label_columns = [feature for feature in train_dataset.features if feature not in ("label", "labels")]
     if data_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     with training_args.main_process_first(desc="train dataset map pre-processing"):
         train_dataset = train_dataset.map(
             preprocess_function,
@@ -385,7 +386,8 @@ def main():
     if not training_args.do_train:
         non_label_columns = [feature for feature in eval_dataset.features if feature not in ("label", "labels")]
     if data_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
     with training_args.main_process_first(desc="validation dataset map pre-processing"):
         eval_dataset = eval_dataset.map(
             preprocess_function,
...
@@ -438,7 +438,8 @@ def main():
     train_dataset = datasets["train"]
     if data_args.max_train_samples is not None:
         # We will select sample from whole data if argument is specified
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     # Create train feature from dataset
     train_dataset = train_dataset.map(
         prepare_train_features,
@@ -449,7 +450,8 @@ def main():
     )
     if data_args.max_train_samples is not None:
         # Number of samples might increase during Feature Creation, we select only specified max samples
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     processed_datasets["train"] = train_dataset
     # Validation preprocessing
@@ -505,7 +507,8 @@ def main():
     eval_examples = datasets["validation"]
     if data_args.max_eval_samples is not None:
         # We will select sample from whole data
-        eval_examples = eval_examples.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
+        eval_examples = eval_examples.select(range(max_eval_samples))
     # Validation Feature Creation
     eval_dataset = eval_examples.map(
         prepare_validation_features,
@@ -516,7 +519,8 @@ def main():
     )
     if data_args.max_eval_samples is not None:
         # During Feature creation dataset samples might increase, we will select required samples again
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
     processed_datasets["validation"] = eval_dataset
     if training_args.do_predict:
@@ -536,7 +540,8 @@ def main():
     )
     if data_args.max_predict_samples is not None:
         # During Feature creation dataset samples might increase, we will select required samples again
-        predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+        max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+        predict_dataset = predict_dataset.select(range(max_predict_samples))
     processed_datasets["test"] = predict_dataset
     # endregion
...
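
The question-answering script above clamps twice because the row count can change across the map() call: a context longer than the model's maximum length is split into several overlapping features, so the dataset can grow during feature creation. A toy sketch of that fan-out; split_into_windows is an illustrative stand-in for the script's prepare_train_features, not the real tokenizer-based implementation:

from datasets import Dataset

def split_into_windows(batch, size=5, stride=3):
    # Each long "context" fans out into several fixed-size windows,
    # so the number of rows grows across map().
    windows = []
    for text in batch["context"]:
        for start in range(0, max(len(text) - size, 0) + 1, stride):
            windows.append(text[start:start + size])
    return {"window": windows}

examples = Dataset.from_dict({"context": ["a" * 12, "b" * 4]})
features = examples.map(split_into_windows, batched=True, remove_columns=["context"])
assert len(features) > len(examples)  # feature creation increased the count

# Hence the second clamp after map(): the first bounds the raw examples,
# the second bounds the expanded features.
features = features.select(range(min(len(features), 3)))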
@@ -490,7 +490,8 @@ def main():
         raise ValueError("--do_train requires a train dataset")
     train_dataset = raw_datasets["train"]
     if data_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     with training_args.main_process_first(desc="train dataset map pre-processing"):
         train_dataset = train_dataset.map(
             preprocess_function,
@@ -509,7 +510,8 @@ def main():
         raise ValueError("--do_eval requires a validation dataset")
     eval_dataset = raw_datasets["validation"]
     if data_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
     with training_args.main_process_first(desc="validation dataset map pre-processing"):
         eval_dataset = eval_dataset.map(
             preprocess_function,
...
@@ -445,7 +445,8 @@ def main():
         raise ValueError("--do_train requires a train dataset")
     train_dataset = raw_datasets["train"]
     if data_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
     with training_args.main_process_first(desc="train dataset map pre-processing"):
         train_dataset = train_dataset.map(
             preprocess_function,
@@ -464,7 +465,8 @@ def main():
         raise ValueError("--do_eval requires a validation dataset")
     eval_dataset = raw_datasets["validation"]
     if data_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
     with training_args.main_process_first(desc="validation dataset map pre-processing"):
         eval_dataset = eval_dataset.map(
             preprocess_function,
...
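
For reference, the failure mode these diffs guard against is a crash when the requested cap exceeds the dataset size; .select() does not silently truncate. A sketch, with the caveat that the exact exception type depends on the datasets version (recent releases raise for out-of-range indices):

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})

try:
    ds.select(range(10))  # pre-fix behavior: indices 3..9 do not exist
except (IndexError, ValueError) as err:
    print("unclamped select failed:", err)

print(len(ds.select(range(min(len(ds), 10)))))  # with the fix: prints 3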