Commit 0ae96ff8 authored by Julien Chaumond, committed by GitHub

BIG Reorganize examples (#4213)

* Created using Colaboratory

* [examples] reorganize files

* remove run_tpu_glue.py as superseded by TPU support in Trainer

* Bugfix: int, not tuple

* move files around
Parent: cafa6a9e
@@ -27,7 +27,7 @@ export CURRENT_DIR=${PWD}
 export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
 mkdir -p $OUTPUT_DIR
-# Add parent directory to python path to access transformer_base.py
+# Add parent directory to python path to access lightning_base.py
 export PYTHONPATH="../":"${PYTHONPATH}"
 python3 run_pl_ner.py --data_dir ./ \
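The comment change above tracks the rename of transformer_base.py to lightning_base.py; the PYTHONPATH export is what lets a script in a subdirectory import that module from its parent directory. A minimal sketch of the same mechanism done from inside Python (the path manipulation is standard library behavior, not code from this commit):

import os
import sys

# Equivalent of `export PYTHONPATH="../":"${PYTHONPATH}"`: put the parent
# directory (where lightning_base.py lives) on the module search path.
sys.path.insert(0, os.path.abspath(".."))

# With the parent directory on sys.path, this import now resolves.
from lightning_base import BaseTransformer, add_generic_args, generic_train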
@@ -9,7 +9,7 @@ from seqeval.metrics import f1_score, precision_score, recall_score
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader, TensorDataset
-from transformer_base import BaseTransformer, add_generic_args, generic_train
+from lightning_base import BaseTransformer, add_generic_args, generic_train
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
@@ -18,10 +18,10 @@ class ExamplesTests(unittest.TestCase):
         testargs = """
             --model_name distilbert-base-german-cased
-            --output_dir ./examples/tests_samples/temp_dir
+            --output_dir ./tests/fixtures/tests_samples/temp_dir
             --overwrite_output_dir
-            --data_dir ./examples/tests_samples/GermEval
-            --labels ./examples/tests_samples/GermEval/labels.txt
+            --data_dir ./tests/fixtures/tests_samples/GermEval
+            --labels ./tests/fixtures/tests_samples/GermEval/labels.txt
             --max_seq_length 128
             --num_train_epochs 6
             --logging_steps 1
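For context, these testargs strings are split and fed to the example scripts as a fake command line. A minimal, self-contained sketch of the idiom, assuming a stand-in parser (parse_flags below is hypothetical, not the real script's entry point):

import argparse
import sys
from unittest.mock import patch

def parse_flags():
    # Hypothetical stand-in for an example script's argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name")
    parser.add_argument("--output_dir")
    parser.add_argument("--max_seq_length", type=int)
    return parser.parse_args()

testargs = """
    run_ner.py
    --model_name distilbert-base-german-cased
    --output_dir ./tests/fixtures/tests_samples/temp_dir
    --max_seq_length 128
""".split()

# Temporarily replace sys.argv so the parser sees the test arguments;
# testargs[0] plays the role of the program name.
with patch.object(sys, "argv", testargs):
    args = parse_flags()

assert args.max_seq_length == 128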
@@ -4,7 +4,7 @@
 This model is a fine tuned RoBERTA model over STS-B.
 It was trained with these params:
-!python /content/transformers/examples/run_glue.py \
+!python /content/transformers/examples/text-classification/run_glue.py \
 --model_type roberta \
 --model_name_or_path roberta-large \
 --task_name STS-B \
@@ -333,7 +333,7 @@ class Trainer:
         total_train_batch_size = (
             self.args.train_batch_size
             * self.args.gradient_accumulation_steps
-            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
+            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
         )
         logger.info("***** Running training *****")
         logger.info("  Num examples = %d", num_examples)
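This is the "Bugfix: int, not tuple" from the commit message: the trailing comma inside the parentheses silently turned the computed batch size into a one-element tuple, which then breaks %d-style log formatting. A minimal reproduction (values made up for illustration):

train_batch_size = 8
gradient_accumulation_steps = 4
world_size = 2

# Buggy: the trailing comma makes the parenthesized expression a tuple.
total_train_batch_size = (
    train_batch_size
    * gradient_accumulation_steps
    * world_size,
)
print(total_train_batch_size)  # (64,) -- a tuple, so "%d" formatting on it fails

# Fixed: without the comma, the parentheses only group the expression.
total_train_batch_size = (
    train_batch_size
    * gradient_accumulation_steps
    * world_size
)
print(total_train_batch_size)  # 64 -- a plain int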