Unverified Commit a73281e3 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[examples] max samples can't be bigger than the len of dataset (#16501)

* [examples] max samples can't be bigger than then len of dataset

* do tf and flax
parent c4deb7b3
......@@ -415,9 +415,11 @@ def main():
train_dataset = train_dataset.select(train_indices)
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
......
......@@ -456,9 +456,11 @@ def main():
train_dataset = train_dataset.select(train_indices)
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
......
......@@ -369,7 +369,8 @@ def main():
train_dataset = raw_datasets["train"]
non_label_columns = [feature for feature in train_dataset.features if feature not in ("label", "labels")]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
......@@ -385,7 +386,8 @@ def main():
if not training_args.do_train:
non_label_columns = [feature for feature in eval_dataset.features if feature not in ("label", "labels")]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function,
......
......@@ -438,7 +438,8 @@ def main():
train_dataset = datasets["train"]
if data_args.max_train_samples is not None:
# We will select sample from whole data if agument is specified
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
# Create train feature from dataset
train_dataset = train_dataset.map(
prepare_train_features,
......@@ -449,7 +450,8 @@ def main():
)
if data_args.max_train_samples is not None:
# Number of samples might increase during Feature Creation, We select only specified max samples
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
processed_datasets["train"] = train_dataset
# Validation preprocessing
......@@ -505,7 +507,8 @@ def main():
eval_examples = datasets["validation"]
if data_args.max_eval_samples is not None:
# We will select sample from whole data
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
eval_examples = eval_examples.select(range(max_eval_samples))
# Validation Feature Creation
eval_dataset = eval_examples.map(
prepare_validation_features,
......@@ -516,7 +519,8 @@ def main():
)
if data_args.max_eval_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
processed_datasets["validation"] = eval_dataset
if training_args.do_predict:
......@@ -536,7 +540,8 @@ def main():
)
if data_args.max_predict_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
processed_datasets["test"] = predict_dataset
# endregion
......
......@@ -490,7 +490,8 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
......@@ -509,7 +510,8 @@ def main():
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function,
......
......@@ -445,7 +445,8 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
......@@ -464,7 +465,8 @@ def main():
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment