Unverified Commit f2ffcaf4 authored by Tommy Chiang's avatar Tommy Chiang Committed by GitHub
Browse files

[Examples] Check key exists in datasets first (#11503)

parent ba0d50f2
...@@ -347,9 +347,9 @@ def main(): ...@@ -347,9 +347,9 @@ def main():
return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()} return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
if training_args.do_train: if training_args.do_train:
train_dataset = datasets["train"]
if "train" not in datasets: if "train" not in datasets:
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
......
...@@ -422,9 +422,9 @@ def main(): ...@@ -422,9 +422,9 @@ def main():
return model_inputs return model_inputs
if training_args.do_train: if training_args.do_train:
train_dataset = datasets["train"]
if "train" not in datasets: if "train" not in datasets:
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
......
...@@ -416,9 +416,9 @@ def main(): ...@@ -416,9 +416,9 @@ def main():
return model_inputs return model_inputs
if training_args.do_train: if training_args.do_train:
train_dataset = datasets["train"]
if "train" not in datasets: if "train" not in datasets:
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment