Unverified commit dd522da0, authored by vblagoje and committed by GitHub
Browse files

Fix PL token classification examples (#6682)

parent a5737779
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \ ## The relevant files are currently on a shared Google
## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J
## Monitor for changes and eventually migrate to nlp dataset
curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \ curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \ curl -L 'https://drive.google.com/uc?export=download&id=1u9mb7kNJHWQCWyweMDRMuTFoOHOfeBTH' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
export MAX_LENGTH=128 export MAX_LENGTH=128
......
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
# for seqeval metrics import # for seqeval metrics import
pip install -r ../requirements.txt pip install -r ../requirements.txt
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \ ## The relevant files are currently on a shared Google
## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J
## Monitor for changes and eventually migrate to nlp dataset
curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \ curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \ curl -L 'https://drive.google.com/uc?export=download&id=1u9mb7kNJHWQCWyweMDRMuTFoOHOfeBTH' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
export MAX_LENGTH=128 export MAX_LENGTH=128
...@@ -29,7 +32,6 @@ mkdir -p $OUTPUT_DIR ...@@ -29,7 +32,6 @@ mkdir -p $OUTPUT_DIR
export PYTHONPATH="../":"${PYTHONPATH}" export PYTHONPATH="../":"${PYTHONPATH}"
python3 run_pl_ner.py --data_dir ./ \ python3 run_pl_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \ --labels ./labels.txt \
--model_name_or_path $BERT_MODEL \ --model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \ --output_dir $OUTPUT_DIR \
...@@ -37,5 +39,6 @@ python3 run_pl_ner.py --data_dir ./ \ ...@@ -37,5 +39,6 @@ python3 run_pl_ner.py --data_dir ./ \
--num_train_epochs $NUM_EPOCHS \ --num_train_epochs $NUM_EPOCHS \
--train_batch_size $BATCH_SIZE \ --train_batch_size $BATCH_SIZE \
--seed $SEED \ --seed $SEED \
--gpus 1 \
--do_train \ --do_train \
--do_predict --do_predict
...@@ -86,7 +86,7 @@ class NERTransformer(BaseTransformer): ...@@ -86,7 +86,7 @@ class NERTransformer(BaseTransformer):
logger.info("Saving features into cached file %s", cached_features_file) logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file) torch.save(features, cached_features_file)
def get_dataloader(self, mode: int, batch_size: int) -> DataLoader: def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
"Load datasets. Called after prepare data." "Load datasets. Called after prepare data."
cached_features_file = self._feature_file(mode) cached_features_file = self._feature_file(mode)
logger.info("Loading features from cached file %s", cached_features_file) logger.info("Loading features from cached file %s", cached_features_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment