Unverified Commit 1e616c0a authored by Stefan Schweter's avatar Stefan Schweter Committed by GitHub
Browse files

NER: parse args from .args file or JSON (#4110)

* ner: parse args from .args file or JSON

* examples: mention json-based configuration file support for run_ner script
parent abb1fa3f
...@@ -79,6 +79,29 @@ python3 run_ner.py --data_dir ./ \ ...@@ -79,6 +79,29 @@ python3 run_ner.py --data_dir ./ \
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets. If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
### JSON-based configuration file
Instead of passing all parameters via commandline arguments, the `run_ner.py` script also supports reading parameters from a json-based configuration file:
```json
{
"data_dir": ".",
"labels": "./labels.txt",
"model_name_or_path": "bert-base-multilingual-cased",
"output_dir": "germeval-model",
"max_seq_length": 128,
"num_train_epochs": 3,
"per_gpu_train_batch_size": 32,
"save_steps": 750,
"seed": 1,
"do_train": true,
"do_eval": true,
"do_predict": true
}
```
It must be saved with a `.json` extension and can be used by running `python3 run_ner.py config.json`.
#### Evaluation #### Evaluation
Evaluation on development dataset outputs the following for our example: Evaluation on development dataset outputs the following for our example:
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import logging import logging
import os import os
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -94,6 +95,11 @@ def main(): ...@@ -94,6 +95,11 @@ def main():
# We now keep distinct sets of args, for a cleaner separation of concerns. # We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( if (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment