Commit 1c7253cc authored Nov 11, 2019 by Stefan Schweter

modeling: add DistilBertForTokenClassification implementation

parent 1c542df7

Showing 1 changed file with 73 additions and 0 deletions.

transformers/modeling_distilbert.py  (+73, -0)
@@ -30,6 +30,7 @@ import numpy as np

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_distilbert import DistilBertConfig
@@ -702,3 +703,75 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):

            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of
                         the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """
    def __init__(self, config):
        super(DistilBertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
        outputs = self.distilbert(input_ids,
                                  attention_mask=attention_mask,
                                  head_mask=head_mask,
                                  inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
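As a quick orientation for readers of this diff, the sketch below shows how the new head might be used end to end for NER-style tagging. It is only an illustration under assumptions not made by the commit itself: the label list, the `num_labels` keyword forwarded through `from_pretrained`, and the argmax decoding are hypothetical, and the class is imported from `transformers.modeling_distilbert` directly because this commit does not touch the package `__init__`.

    # Hypothetical usage sketch (not part of the commit): token classification with the new head.
    import torch

    from transformers import DistilBertTokenizer
    from transformers.modeling_distilbert import DistilBertForTokenClassification

    # Hypothetical tag set for illustration; any scheme works as long as num_labels matches.
    label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    # num_labels is assumed to be forwarded to DistilBertConfig so the classifier gets the right
    # output width; the token-classification head itself starts from randomly initialized weights.
    model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased",
                                                              num_labels=len(label_list))
    model.eval()

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    attention_mask = torch.ones_like(input_ids)  # single unpadded sequence: every position is "active"

    with torch.no_grad():
        scores = model(input_ids, attention_mask=attention_mask)[0]  # (1, sequence_length, num_labels)

    # One predicted tag per (sub)token; in padded batches, positions where attention_mask == 0
    # are the same ones the forward pass above excludes from the loss.
    predictions = scores.argmax(dim=-1)
    print([label_list[p] for p in predictions[0].tolist()])

The `active_loss` masking in `forward` serves the same purpose on the training side: cross-entropy is computed only over positions where `attention_mask == 1`, so padding tokens never contribute to the loss.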