chenpangpang / transformers / Commits
Commit 2b07b9e5 authored Nov 11, 2019 by Stefan Schweter

examples: add DistilBert support for NER fine-tuning

parent 1806eabf
Showing 1 changed file with 10 additions and 6 deletions
examples/run_ner.py  +10  -6
@@ -36,16 +36,18 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
 from transformers import AdamW, WarmupLinearSchedule
 from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
 from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer

 logger = logging.getLogger(__name__)

 ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
+    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
     ())

 MODEL_CLASSES = {
     "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
+    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
+    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
 }
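
Note: the new "distilbert" entry follows the same (config class, model class, tokenizer class) convention as the existing "bert" and "roberta" entries, so the rest of the script can unpack and instantiate it without special-casing. A rough sketch of how such an entry is typically consumed (not part of this commit; the checkpoint name "distilbert-base-uncased" and the toy label list are assumptions for illustration):

    from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer

    MODEL_CLASSES = {
        "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
    }

    labels = ["O", "B-PER", "I-PER"]  # toy NER label set, illustration only
    config_class, model_class, tokenizer_class = MODEL_CLASSES["distilbert"]

    # Load config, tokenizer, and token-classification model from the same checkpoint
    config = config_class.from_pretrained("distilbert-base-uncased", num_labels=len(labels))
    tokenizer = tokenizer_class.from_pretrained("distilbert-base-uncased")
    model = model_class.from_pretrained("distilbert-base-uncased", config=config)
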
@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {"input_ids": batch[0],
                       "attention_mask": batch[1],
-                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
-                      # XLM and RoBERTa don"t use segment_ids
                       "labels": batch[3]}
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don"t use segment_ids
+
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
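
Note: the guard is needed because DistilBERT has no segment (token type) embeddings, so its forward() takes no token_type_ids argument; the key has to be left out of the inputs dict entirely rather than set to None. A self-contained sketch of the same pattern (the helper name and dummy tensors are illustrative, not taken from run_ner.py):

    import torch

    def build_inputs(batch, model_type):
        # batch = (input_ids, attention_mask, token_type_ids, labels)
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
        if model_type != "distilbert":
            # XLM and RoBERTa don't use segment ids, so they get None here
            inputs["token_type_ids"] = batch[2] if model_type in ["bert", "xlnet"] else None
        return inputs

    dummy = tuple(torch.zeros(2, 8, dtype=torch.long) for _ in range(4))
    print(sorted(build_inputs(dummy, "distilbert")))  # no token_type_ids key
    print(sorted(build_inputs(dummy, "bert")))        # includes token_type_ids
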
@@ -206,9 +209,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
         with torch.no_grad():
             inputs = {"input_ids": batch[0],
                       "attention_mask": batch[1],
-                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
-                      # XLM and RoBERTa don"t use segment_ids
                       "labels": batch[3]}
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don"t use segment_ids
             outputs = model(**inputs)
             tmp_eval_loss, logits = outputs[:2]
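
Note: as in the training loop, supplying labels makes the token-classification head return the loss first, and in the transformers versions of this era (v2.x) model outputs are plain tuples, which is what outputs[:2] unpacks into (loss, logits). A minimal forward-pass sketch under those assumptions (checkpoint name and dummy labels are illustrative only):

    import torch
    from transformers import DistilBertForTokenClassification, DistilBertTokenizer

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=3)

    input_ids = torch.tensor([tokenizer.encode("Hugging Face is based in New York")])
    attention_mask = torch.ones_like(input_ids)
    dummy_labels = torch.zeros_like(input_ids)  # all-"O" labels, illustration only

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=dummy_labels)
    loss, logits = outputs[:2]
    print(loss.item(), logits.shape)  # scalar loss, (batch, seq_len, num_labels)
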
@@ -520,3 +523,4 @@ def main():
 if __name__ == "__main__":
     main()