Unverified Commit 6f52fce6 authored by Allen Wang's avatar Allen Wang Committed by GitHub
Browse files

Fixes an issue in `text-classification` where MNLI eval/test datasets are not...

Fixes an issue in `text-classification` where MNLI eval/test datasets are not being preprocessed. (#10621)

* Fix MNLI tests

* Linter fix
parent 72d9e039
...@@ -374,17 +374,13 @@ def main(): ...@@ -374,17 +374,13 @@ def main():
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result return result
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
if training_args.do_train: if training_args.do_train:
if "train" not in datasets: if "train" not in datasets:
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"] train_dataset = datasets["train"]
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_eval: if training_args.do_eval:
if "validation" not in datasets and "validation_matched" not in datasets: if "validation" not in datasets and "validation_matched" not in datasets:
...@@ -392,11 +388,6 @@ def main(): ...@@ -392,11 +388,6 @@ def main():
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_val_samples is not None: if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in datasets and "test_matched" not in datasets: if "test" not in datasets and "test_matched" not in datasets:
...@@ -404,15 +395,11 @@ def main(): ...@@ -404,15 +395,11 @@ def main():
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_test_samples is not None: if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples)) test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set: # Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3): if training_args.do_train:
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Get the metric function # Get the metric function
if data_args.task_name is not None: if data_args.task_name is not None:
...@@ -447,7 +434,7 @@ def main(): ...@@ -447,7 +434,7 @@ def main():
trainer = Trainer( trainer = Trainer(
model=model, model=model,
args=training_args, args=training_args,
train_dataset=train_dataset, train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None, eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics, compute_metrics=compute_metrics,
tokenizer=tokenizer, tokenizer=tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment