Unverified Commit 04e25c62 authored by Philipp Schmid's avatar Philipp Schmid Committed by GitHub
Browse files

add `dataset_name` to data_args and added accuracy metric (#11760)

* add `dataset_name` to data_args and added accuracy metric

* added documentation for dataset_name

* spelling correction
parent fd3b12e8
......@@ -22,8 +22,8 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
and can also be used for your own data in a csv or a JSON file (the script might need some tweaks in that case, refer
to the comments inside for help).
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
(the script might need some tweaks in that case, refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
......@@ -64,6 +64,22 @@ single Titan RTX was used):
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
The following example fine-tunes BERT on the `imdb` dataset hosted on our [hub](https://huggingface.co/datasets):
```bash
python run_glue.py \
--model_name_or_path bert-base-cased \
--dataset_name imdb \
--do_train \
--do_predict \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir /tmp/imdb/
```
### Mixed precision training
If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision
......
......@@ -76,6 +76,12 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
)
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
max_seq_length: int = field(
default=128,
metadata={
......@@ -127,8 +133,10 @@ class DataTrainingArguments:
self.task_name = self.task_name.lower()
if self.task_name not in task_to_keys.keys():
raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
elif self.dataset_name is not None:
pass
elif self.train_file is None or self.validation_file is None:
raise ValueError("Need either a GLUE task or a training/validation file.")
raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
else:
train_extension = self.train_file.split(".")[-1]
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
......@@ -240,6 +248,9 @@ def main():
if data_args.task_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
elif data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
else:
# Loading a dataset from your local files.
# CSV/JSON training and evaluation files are needed.
......@@ -408,8 +419,8 @@ def main():
# Get the metric function
if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name)
# TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
# compute_metrics
else:
metric = load_metric("accuracy")
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment