"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "7b1f6a850f9b7507d7cb6afc1d576e3bb5b265d0"
Commit 2b07b9e5 authored by Stefan Schweter

examples: add DistilBert support for NER fine-tuning

parent 1806eabf
@@ -36,16 +36,18 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr
 from transformers import AdamW, WarmupLinearSchedule
 from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
 from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
 logger = logging.getLogger(__name__)
 ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
+    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
     ())
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
+    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
+    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
 }
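For context on the hunk above: the new "distilbert" entry in MODEL_CLASSES is what lets run_ner.py resolve DistilBERT's config, model, and tokenizer classes from the --model_type argument. A minimal sketch of that lookup, assuming the usual from_pretrained API; the checkpoint name, do_lower_case flag, and the nine-tag CoNLL-style label count are illustrative, not taken from the commit:

    from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer

    MODEL_CLASSES = {
        "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
    }

    # Resolve the three classes for the requested model type, then load pretrained weights.
    config_class, model_class, tokenizer_class = MODEL_CLASSES["distilbert"]
    config = config_class.from_pretrained("distilbert-base-uncased", num_labels=9)
    tokenizer = tokenizer_class.from_pretrained("distilbert-base-uncased", do_lower_case=True)
    model = model_class.from_pretrained("distilbert-base-uncased", config=config)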
@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {"input_ids": batch[0],
                       "attention_mask": batch[1],
-                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
-                      # XLM and RoBERTa don"t use segment_ids
                       "labels": batch[3]}
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
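The guard added above (and mirrored in evaluate below) is needed because DistilBERT, unlike BERT and XLNet, has no token_type_ids parameter in its forward signature, so even passing None for it would raise a TypeError. A quick way to check that assumption locally, not part of the commit:

    import inspect
    from transformers import BertForTokenClassification, DistilBertForTokenClassification

    # BERT exposes segment ids; DistilBERT does not accept the argument at all.
    assert "token_type_ids" in inspect.signature(BertForTokenClassification.forward).parameters
    assert "token_type_ids" not in inspect.signature(DistilBertForTokenClassification.forward).parameters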
@@ -206,9 +209,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
         with torch.no_grad():
             inputs = {"input_ids": batch[0],
                       "attention_mask": batch[1],
-                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
-                      # XLM and RoBERTa don"t use segment_ids
                       "labels": batch[3]}
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids
             outputs = model(**inputs)
             tmp_eval_loss, logits = outputs[:2]
@@ -520,3 +523,4 @@ def main():
 if __name__ == "__main__":
     main()
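Taken together, the changes let the script fine-tune and evaluate DistilBERT for token classification the same way as BERT and RoBERTa. A rough smoke test of that path, with an illustrative checkpoint name, dummy labels, and the tuple-style outputs this version of the library returns; none of this code is from the commit:

    import torch
    from transformers import DistilBertForTokenClassification, DistilBertTokenizer

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=9)
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Stefan lives in Munich")])
    labels = torch.zeros_like(input_ids)  # dummy label ids, as if every token were "O"

    loss, logits = model(input_ids, labels=labels)[:2]
    print(loss.item(), logits.shape)  # logits: (1, sequence_length, num_labels)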