Unverified Commit 04e25c62 authored by Philipp Schmid's avatar Philipp Schmid Committed by GitHub
Browse files

add `dataset_name` to data_args and added accuracy metric (#11760)

* add `dataset_name` to data_args and added accuracy metric

* added documentation for dataset_name

* spelling correction
parent fd3b12e8
@@ -22,8 +22,8 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/

Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
(the script might need some tweaks in that case, refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
@@ -64,6 +64,22 @@ single Titan RTX was used):

Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
The following example fine-tunes BERT on the `imdb` dataset hosted on our [hub](https://huggingface.co/datasets):
```bash
python run_glue.py \
--model_name_or_path bert-base-cased \
--dataset_name imdb \
--do_train \
--do_predict \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir /tmp/imdb/
```
### Mixed precision training

If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision
......
...@@ -76,6 +76,12 @@ class DataTrainingArguments: ...@@ -76,6 +76,12 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
) )
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
max_seq_length: int = field( max_seq_length: int = field(
default=128, default=128,
metadata={ metadata={
...@@ -127,8 +133,10 @@ class DataTrainingArguments: ...@@ -127,8 +133,10 @@ class DataTrainingArguments:
self.task_name = self.task_name.lower() self.task_name = self.task_name.lower()
if self.task_name not in task_to_keys.keys(): if self.task_name not in task_to_keys.keys():
raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
elif self.dataset_name is not None:
pass
elif self.train_file is None or self.validation_file is None: elif self.train_file is None or self.validation_file is None:
raise ValueError("Need either a GLUE task or a training/validation file.") raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
else: else:
train_extension = self.train_file.split(".")[-1] train_extension = self.train_file.split(".")[-1]
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
...@@ -240,6 +248,9 @@ def main(): ...@@ -240,6 +248,9 @@ def main():
if data_args.task_name is not None: if data_args.task_name is not None:
# Downloading and loading a dataset from the hub. # Downloading and loading a dataset from the hub.
datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
elif data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
else: else:
# Loading a dataset from your local files. # Loading a dataset from your local files.
# CSV/JSON training and evaluation files are needed. # CSV/JSON training and evaluation files are needed.
...@@ -408,8 +419,8 @@ def main(): ...@@ -408,8 +419,8 @@ def main():
# Get the metric function # Get the metric function
if data_args.task_name is not None: if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name) metric = load_metric("glue", data_args.task_name)
# TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from else:
# compute_metrics metric = load_metric("accuracy")
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float. # predictions and label_ids field) and has to return a dictionary string to float.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment