chenpangpang / transformers · Commits · 66085a13

Commit 66085a13, authored Oct 23, 2019 by Matt Maybeno, committed by Julien Chaumond on Oct 24, 2019
Parent: 5b6cafb1

RoBERTa token classification

[WIP] copy paste bert token classification for roberta

Showing 5 changed files with 158 additions and 1 deletion (+158 / -1)
transformers/__init__.py                        +2   -0
transformers/modeling_roberta.py                +72  -0
transformers/modeling_tf_roberta.py             +51  -0
transformers/tests/modeling_roberta_test.py     +18  -1
transformers/tests/modeling_tf_roberta_test.py  +15  -0
transformers/__init__.py

@@ -89,6 +89,7 @@ if is_torch_available():
                               XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
                                    RobertaForSequenceClassification, RobertaForMultipleChoice,
+                                   RobertaForTokenClassification,
                                    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                       DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
@@ -139,6 +140,7 @@ if is_tf_available():
     from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
                                       TFRobertaModel, TFRobertaForMaskedLM,
                                       TFRobertaForSequenceClassification,
+                                      TFRobertaForTokenClassification,
                                       TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
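With these exports in place, both new heads are importable from the package root. A minimal sketch of the imports this hunk enables, guarded the same way the package guards its optional backends (nothing here beyond what the diff adds):

    from transformers import is_torch_available, is_tf_available

    if is_torch_available():
        from transformers import RobertaForTokenClassification     # PyTorch head added by this commit

    if is_tf_available():
        from transformers import TFRobertaForTokenClassification   # TensorFlow head added by this commit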
transformers/modeling_roberta.py

@@ -343,6 +343,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         return outputs  # (loss), logits, (hidden_states), (attentions)
 
 
+
 @add_start_docstrings("""Roberta Model with a multiple choice classification head on top (a linear layer on top of
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
     ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -451,6 +452,77 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
 
 
+@add_start_docstrings("""Roberta Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
+class RobertaForTokenClassification(BertPreTrainedModel):
+    r"""
+        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss.
+        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+            Classification scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+        model = RobertaForTokenClassification.from_pretrained('roberta-base')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
+
+    """
+    def __init__(self, config):
+        super(RobertaForTokenClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta = RobertaModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, labels=None):
+
+        outputs = self.roberta(input_ids,
+                               attention_mask=attention_mask,
+                               token_type_ids=token_type_ids,
+                               position_ids=position_ids,
+                               head_mask=head_mask)
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)[active_loss]
+                active_labels = labels.view(-1)[active_loss]
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), scores, (hidden_states), (attentions)
+
+
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
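Reviewer note: the only non-boilerplate logic in the new forward is the attention-mask filtering, which drops padded positions before computing the loss. Below is a minimal standalone sketch of that masking step, using made-up shapes and labels rather than the model itself (purely illustrative, not part of this commit):

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, seq_len, num_labels = 1, 4, 3
    logits = torch.randn(batch_size, seq_len, num_labels)     # stand-in for the classifier output
    labels = torch.tensor([[0, 2, 1, 1]])                     # one label id per token
    attention_mask = torch.tensor([[1, 1, 1, 0]])             # last position is padding

    loss_fct = CrossEntropyLoss()
    active_loss = attention_mask.view(-1) == 1                # boolean mask over flattened positions
    active_logits = logits.view(-1, num_labels)[active_loss]  # keep only real tokens
    active_labels = labels.view(-1)[active_loss]
    loss = loss_fct(active_logits, active_labels)             # padded positions never reach the loss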
transformers/modeling_tf_roberta.py

@@ -371,3 +371,54 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
         outputs = (logits,) + outputs[2:]
 
         return outputs  # logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""RoBERTa Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
+class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
+    r"""
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+            Classification scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import RobertaTokenizer, TFRobertaForTokenClassification
+
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+        model = TFRobertaForTokenClassification.from_pretrained('roberta-base')
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        outputs = model(input_ids)
+        scores = outputs[0]
+
+    """
+    def __init__(self, config, *inputs, **kwargs):
+        super(TFRobertaForTokenClassification, self).__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.roberta = TFRobertaMainLayer(config, name='roberta')
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(config.num_labels,
+                                                kernel_initializer=get_initializer(config.initializer_range),
+                                                name='classifier')
+
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        return outputs  # scores, (hidden_states), (attentions)
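Unlike the PyTorch head above, the TF head returns only per-token logits and computes no loss. A hedged sketch of how a caller might build an equivalent masked loss on top of those logits (this helper code is not part of the commit; shapes, labels, and the choice of SparseCategoricalCrossentropy are illustrative):

    import tensorflow as tf

    num_labels = 3
    logits = tf.random.normal([1, 4, num_labels])    # stand-in for the model's per-token scores
    labels = tf.constant([[0, 2, 1, 1]])
    attention_mask = tf.constant([[1, 1, 1, 0]])     # last position is padding

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    per_token_loss = loss_fn(labels, logits)                     # shape (1, 4), one value per token
    mask = tf.cast(attention_mask, per_token_loss.dtype)
    loss = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)  # padding excluded from the mean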
transformers/tests/modeling_roberta_test.py

@@ -24,7 +24,8 @@ from transformers import is_torch_available
 if is_torch_available():
     import torch
-    from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
+    from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM,
+                              RobertaForSequenceClassification, RobertaForTokenClassification)
     from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")
@@ -156,6 +157,22 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.check_loss_output(result)
 
+        def create_and_check_roberta_for_token_classification(self, config, input_ids, token_type_ids, input_mask,
+                                                              sequence_labels, token_labels, choice_labels):
+            config.num_labels = self.num_labels
+            model = RobertaForTokenClassification(config=config)
+            model.eval()
+            loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
+                                 labels=token_labels)
+            result = {
+                "loss": loss,
+                "logits": logits,
+            }
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.seq_length, self.num_labels])
+            self.check_loss_output(result)
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids, token_type_ids, input_mask,
transformers/tests/modeling_tf_roberta_test.py

@@ -30,6 +30,7 @@ if is_tf_available():
     import numpy
     from transformers.modeling_tf_roberta import (TFRobertaModel, TFRobertaForMaskedLM,
                                                   TFRobertaForSequenceClassification,
+                                                  TFRobertaForTokenClassification,
                                                   TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
     pytestmark = pytest.mark.skip("Require TensorFlow")
@@ -154,6 +155,20 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
                 list(result["prediction_scores"].shape),
                 [self.batch_size, self.seq_length, self.vocab_size])
 
+        def create_and_check_roberta_for_token_classification(self, config, input_ids, token_type_ids, input_mask,
+                                                              sequence_labels, token_labels, choice_labels):
+            config.num_labels = self.num_labels
+            model = TFRobertaForTokenClassification(config=config)
+            inputs = {'input_ids': input_ids,
+                      'attention_mask': input_mask,
+                      'token_type_ids': token_type_ids}
+            logits, = model(inputs)
+            result = {
+                "logits": logits.numpy(),
+            }
+            self.parent.assertListEqual(
+                list(result["logits"].shape),
+                [self.batch_size, self.seq_length, self.num_labels])
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids, token_type_ids, input_mask,
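The TF test exercises the dict-input calling convention rather than the positional call shown in the docstring example. A rough usage sketch combining the two (assumes the 'roberta-base' checkpoint can be downloaded; note the token-classification layer itself is newly initialized here, so the scores are untrained):

    import tensorflow as tf
    from transformers import RobertaTokenizer, TFRobertaForTokenClassification

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = TFRobertaForTokenClassification.from_pretrained('roberta-base')

    ids = tokenizer.encode("Hello, my dog is cute")
    inputs = {'input_ids': tf.constant(ids)[None, :],              # Batch size 1
              'attention_mask': tf.constant([1] * len(ids))[None, :]}

    scores, = model(inputs)   # single-element tuple: (batch_size, sequence_length, num_labels)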