chenpangpang / transformers · Commit e444648a (Unverified)

Authored May 28, 2020 by Suraj Patil; committed by GitHub on May 28, 2020.

LongformerForTokenClassification (#4638)
Parent: 3cc2c2a1

Showing 4 changed files, with 122 additions and 0 deletions (+122, -0):
src/transformers/__init__.py             +1   -0
src/transformers/modeling_auto.py        +2   -0
src/transformers/modeling_longformer.py  +99  -0
tests/test_modeling_longformer.py        +20  -0
src/transformers/__init__.py

@@ -326,6 +326,7 @@ if is_torch_available():
         LongformerModel,
         LongformerForMaskedLM,
         LongformerForSequenceClassification,
+        LongformerForTokenClassification,
         LongformerForQuestionAnswering,
         LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
     )
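With the export added above, the new head becomes importable from the package root. A minimal sketch, not part of the diff, assuming an editable install of this branch with PyTorch available:

from transformers import LongformerForTokenClassification

# base_model_prefix is set to "longformer" in the class added later in this diff
print(LongformerForTokenClassification.base_model_prefix)  # -> "longformer"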
src/transformers/modeling_auto.py

@@ -106,6 +106,7 @@ from .modeling_longformer import (
     LongformerForMaskedLM,
     LongformerForQuestionAnswering,
     LongformerForSequenceClassification,
+    LongformerForTokenClassification,
     LongformerModel,
 )
 from .modeling_marian import MarianMTModel

@@ -282,6 +283,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
         (CamembertConfig, CamembertForTokenClassification),
         (XLMConfig, XLMForTokenClassification),
         (XLMRobertaConfig, XLMRobertaForTokenClassification),
+        (LongformerConfig, LongformerForTokenClassification),
         (RobertaConfig, RobertaForTokenClassification),
         (BertConfig, BertForTokenClassification),
         (XLNetConfig, XLNetForTokenClassification),
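Registering the (LongformerConfig, LongformerForTokenClassification) pair in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING lets the auto classes resolve a Longformer config to the new head. A hedged sketch, not part of the diff; the checkpoint name and num_labels value are illustrative:

from transformers import AutoConfig, AutoModelForTokenClassification

# Build a Longformer config with a token-classification label count,
# then let the auto class pick the matching head via the new mapping entry.
config = AutoConfig.from_pretrained("allenai/longformer-base-4096", num_labels=9)
model = AutoModelForTokenClassification.from_config(config)
# model is now an instance of LongformerForTokenClassification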
src/transformers/modeling_longformer.py

@@ -971,3 +971,102 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
             outputs = (total_loss,) + outputs
 
         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Longformer Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    LONGFORMER_START_DOCSTRING,
+)
+class LongformerForTokenClassification(BertPreTrainedModel):
+    config_class = LongformerConfig
+    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "longformer"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.longformer = LongformerModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        inputs_embeds=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
+            Classification loss.
+        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import LongformerTokenizer, LongformerForTokenClassification
+        import torch
+
+        tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
+        model = LongformerForTokenClassification.from_pretrained('longformer-base-4096')
+
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
+        """
+
+        outputs = self.longformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), scores, (hidden_states), (attentions)
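The loss computation above masks out padded positions by mapping them to CrossEntropyLoss.ignore_index, so only tokens with attention_mask == 1 contribute to the token-classification loss. A standalone sketch of that masking on toy tensors (not part of the diff):

import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(2, 4, num_labels)          # (batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 4))   # toy gold labels
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])   # 0 marks padding

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
# Padded positions get ignore_index (-100 by default) and are skipped by the loss.
active_labels = torch.where(
    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)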
tests/test_modeling_longformer.py

@@ -30,6 +30,7 @@ if is_torch_available():
         LongformerModel,
         LongformerForMaskedLM,
         LongformerForSequenceClassification,
+        LongformerForTokenClassification,
         LongformerForQuestionAnswering,
     )

@@ -212,6 +213,21 @@ class LongformerModelTester(object):
         self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
         self.check_loss_output(result)
 
+    def create_and_check_longformer_for_token_classification(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_labels = self.num_labels
+        model = LongformerForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        result = {
+            "loss": loss,
+            "logits": logits,
+        }
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (

@@ -278,6 +294,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
 
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
+
 
 class LongformerModelIntegrationTest(unittest.TestCase):
     @slow
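To exercise only the new test locally, one option is to run its node id through pytest. A hedged sketch (assumes pytest is installed, as the repository's test suite expects; run from the repository root):

import pytest

# Node id built from the test class and the method added in this diff.
pytest.main(["tests/test_modeling_longformer.py::LongformerModelTest::test_for_token_classification", "-q"])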