Unverified Commit 6f52fce6 authored by Allen Wang's avatar Allen Wang Committed by GitHub
Browse files

Fixes an issue in `text-classification` where MNLI eval/test datasets are not...

Fixes an issue in `text-classification` where MNLI eval/test datasets are not being preprocessed. (#10621)

* Fix MNLI tests

* Linter fix
parent 72d9e039
......@@ -374,17 +374,13 @@ def main():
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
if training_args.do_train:
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_eval:
if "validation" not in datasets and "validation_matched" not in datasets:
......@@ -392,11 +388,6 @@ def main():
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in datasets and "test_matched" not in datasets:
......@@ -404,13 +395,9 @@ def main():
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set:
if training_args.do_train:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
......@@ -447,7 +434,7 @@ def main():
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment