Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
467e9158
Unverified
Commit
467e9158
authored
Dec 18, 2020
by
sandip
Committed by
GitHub
Dec 17, 2020
Browse files
Added TF CTRL Sequence Classification (#9151)
* Added TF CTRL Sequence Classification * code refactor
parent
63841c55
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
200 additions
and
3 deletions
+200
-3
docs/source/model_doc/ctrl.rst
docs/source/model_doc/ctrl.rst
+5
-0
src/transformers/__init__.py
src/transformers/__init__.py
+1
-0
src/transformers/models/auto/modeling_tf_auto.py
src/transformers/models/auto/modeling_tf_auto.py
+2
-1
src/transformers/models/ctrl/__init__.py
src/transformers/models/ctrl/__init__.py
+1
-0
src/transformers/models/ctrl/modeling_tf_ctrl.py
src/transformers/models/ctrl/modeling_tf_ctrl.py
+160
-1
src/transformers/utils/dummy_tf_objects.py
src/transformers/utils/dummy_tf_objects.py
+9
-0
tests/test_modeling_tf_ctrl.py
tests/test_modeling_tf_ctrl.py
+22
-1
No files found.
docs/source/model_doc/ctrl.rst
View file @
467e9158
...
...
@@ -97,3 +97,8 @@ TFCTRLLMHeadModel
.. autoclass:: transformers.TFCTRLLMHeadModel
:members: call
TFCTRLForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFCTRLForSequenceClassification
:members: call
src/transformers/__init__.py
View file @
467e9158
...
...
@@ -756,6 +756,7 @@ if is_tf_available():
)
from
.models.ctrl
import
(
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
,
TFCTRLForSequenceClassification
,
TFCTRLLMHeadModel
,
TFCTRLModel
,
TFCTRLPreTrainedModel
,
...
...
src/transformers/models/auto/modeling_tf_auto.py
View file @
467e9158
...
...
@@ -53,7 +53,7 @@ from ..camembert.modeling_tf_camembert import (
TFCamembertForTokenClassification
,
TFCamembertModel
,
)
from
..ctrl.modeling_tf_ctrl
import
TFCTRLLMHeadModel
,
TFCTRLModel
from
..ctrl.modeling_tf_ctrl
import
TFCTRLForSequenceClassification
,
TFCTRLLMHeadModel
,
TFCTRLModel
from
..distilbert.modeling_tf_distilbert
import
(
TFDistilBertForMaskedLM
,
TFDistilBertForMultipleChoice
,
...
...
@@ -342,6 +342,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(
GPT2Config
,
TFGPT2ForSequenceClassification
),
(
MPNetConfig
,
TFMPNetForSequenceClassification
),
(
OpenAIGPTConfig
,
TFOpenAIGPTForSequenceClassification
),
(
CTRLConfig
,
TFCTRLForSequenceClassification
),
]
)
...
...
src/transformers/models/ctrl/__init__.py
View file @
467e9158
...
...
@@ -33,6 +33,7 @@ if is_torch_available():
if
is_tf_available
():
from
.modeling_tf_ctrl
import
(
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
,
TFCTRLForSequenceClassification
,
TFCTRLLMHeadModel
,
TFCTRLModel
,
TFCTRLPreTrainedModel
,
...
...
src/transformers/models/ctrl/modeling_tf_ctrl.py
View file @
467e9158
...
...
@@ -19,11 +19,13 @@ import numpy as np
import
tensorflow
as
tf
from
...file_utils
import
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
from
...modeling_tf_outputs
import
TFBaseModelOutputWithPast
,
TFCausalLMOutputWithPast
from
...modeling_tf_outputs
import
TFBaseModelOutputWithPast
,
TFCausalLMOutputWithPast
,
TFSequenceClassifierOutput
from
...modeling_tf_utils
import
(
TFCausalLanguageModelingLoss
,
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
TFSharedEmbeddings
,
get_initializer
,
input_processing
,
keras_serializable
,
shape_list
,
...
...
@@ -726,3 +728,160 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
hidden_states
=
transformer_outputs
.
hidden_states
,
attentions
=
transformer_outputs
.
attentions
,
)
@add_start_docstrings(
    """
    The CTRL Model transformer with a sequence classification head on top (linear layer).

    :class:`~transformers.TFCTRLForSequenceClassification` uses the last token in order to do the classification, as
    other causal models (e.g. GPT-1, GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
    the last value in each row of the batch).
    """,
    CTRL_START_DOCSTRING,
)
class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        # Classification head: projects the hidden state of the sequence's last
        # (non-padding) token to `num_labels` logits. No bias, matching the
        # GPT-2-style causal sequence classification head.
        self.classifier = tf.keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
            use_bias=False,
        )
        self.transformer = TFCTRLMainLayer(config, name="transformer")

    def get_output_embeddings(self):
        # The shared input/output embedding matrix lives on the main layer.
        return self.transformer.w

    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="ctrl",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``.
        """
        # Normalize positional/keyword/dict-style arguments into a single dict;
        # everything below must read from `inputs[...]`, not the raw parameters.
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )

        transformer_outputs = self.transformer(
            input_ids=inputs["input_ids"],
            past=inputs["past"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            use_cache=inputs["use_cache"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )

        hidden_states = transformer_outputs[0]
        # Per-token logits of shape (batch_size, sequence_length, num_labels);
        # we then select the logits of the last non-padding token per row.
        logits = self.classifier(hidden_states)
        logits_shape = shape_list(logits)
        in_logits = None
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if inputs["input_ids"] is not None:
                # Index of the last non-padding token in each row.
                sequence_lengths = (
                    tf.reduce_sum(
                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
                        -1,
                        keepdims=False,
                    )
                    - 1
                )

                def get_seq_element(sequence_position, input_batch):
                    # Slice out the logits row at `sequence_position` without copying the rest.
                    return tf.strided_slice(
                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
                    )

                result = tf.map_fn(
                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
                )
                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
            else:
                # With `inputs_embeds` we cannot locate padding tokens; fall back
                # to the last position of each row and warn the user.
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        loss = None
        if inputs["labels"] is not None:
            # FIX: read from the canonical `inputs` dict (the raw `input_ids`
            # parameter may be unset when the user passed a dict/keyword input).
            if inputs["input_ids"] is not None:
                batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
            else:
                batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
            assert (
                self.config.pad_token_id is not None or batch_size == 1
            ), "Cannot handle batch sizes > 1 if no padding token is defined."

            if not tf.is_tensor(sequence_lengths):
                # sequence_lengths == -1: take the last position of every row.
                in_logits = logits[0:batch_size, sequence_lengths]

            loss = self.compute_loss(
                tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
            )

        pooled_logits = in_logits if in_logits is not None else logits

        if not inputs["return_dict"]:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
src/transformers/utils/dummy_tf_objects.py
View file @
467e9158
...
...
@@ -429,6 +429,15 @@ class TFCamembertModel:
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
=
None
class TFCTRLForSequenceClassification:
    """Placeholder emitted when TensorFlow is unavailable.

    Any attempt to construct or load this model raises via ``requires_tf``.
    """

    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_tf(cls)
class
TFCTRLLMHeadModel
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_tf
(
self
)
...
...
tests/test_modeling_tf_ctrl.py
View file @
467e9158
...
...
@@ -28,6 +28,7 @@ if is_tf_available():
from
transformers.models.ctrl.modeling_tf_ctrl
import
(
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
,
TFCTRLForSequenceClassification
,
TFCTRLLMHeadModel
,
TFCTRLModel
,
)
...
...
@@ -61,6 +62,7 @@ class TFCTRLModelTester(object):
self
.
num_labels
=
3
self
.
num_choices
=
4
self
.
scope
=
None
self
.
pad_token_id
=
self
.
vocab_size
-
1
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
...
...
@@ -98,6 +100,7 @@ class TFCTRLModelTester(object):
n_ctx
=
self
.
max_position_embeddings
,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
pad_token_id
=
self
.
pad_token_id
,
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
...
...
@@ -132,6 +135,20 @@ class TFCTRLModelTester(object):
result
=
model
(
inputs
)
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
))
def create_and_check_ctrl_for_sequence_classification(
    self, config, input_ids, input_mask, head_mask, token_type_ids, *args
):
    """Build a TFCTRLForSequenceClassification and check the logits shape."""
    config.num_labels = self.num_labels
    sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
    model = TFCTRLForSequenceClassification(config)
    # Dict-style inputs exercise the keyword/dict input path of `call`.
    model_inputs = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "labels": sequence_labels,
    }
    result = model(model_inputs)
    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(result.logits.shape, expected_shape)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
...
...
@@ -154,7 +171,7 @@ class TFCTRLModelTester(object):
@
require_tf
class
TFCTRLModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
TFCTRLModel
,
TFCTRLLMHeadModel
)
if
is_tf_available
()
else
()
all_model_classes
=
(
TFCTRLModel
,
TFCTRLLMHeadModel
,
TFCTRLForSequenceClassification
)
if
is_tf_available
()
else
()
all_generative_model_classes
=
(
TFCTRLLMHeadModel
,)
if
is_tf_available
()
else
()
def
setUp
(
self
):
...
...
@@ -172,6 +189,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_ctrl_lm_head
(
*
config_and_inputs
)
def test_ctrl_sequence_classification_model(self):
    """Run the sequence classification shape check on a fresh config."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_ctrl_for_sequence_classification(*prepared)
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment