chenpangpang / transformers

Commit 7bddbf59
Authored Nov 07, 2019 by Lysandre; committed by Lysandre Debut on Nov 26, 2019

TFAlbertForSequenceClassification

parent f6f38253
Showing 1 changed file with 58 additions and 8 deletions

transformers/modeling_tf_albert.py  (+58 -8)  view file @ 7bddbf59
@@ -479,16 +479,15 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
 
-ALBERT_START_DOCSTRING = r"""    The ALBERT model was proposed in
-    `ALBERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
-    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
-    pre-trained using a combination of masked language modeling objective and next sentence prediction
-    on a large corpus comprising the Toronto Book Corpus and Wikipedia.
+ALBERT_START_DOCSTRING = r"""    The ALBERT model was proposed in
+    `ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`_
+    by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. It presents
+    two parameter-reduction techniques to lower memory consumption and increase the training speed of BERT.
 
     This model is a tf.keras.Model `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
     refer to the TF 2.0 documentation for all matter related to general usage and behavior.
 
-    .. _`ALBERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
-        https://arxiv.org/abs/1810.04805
+    .. _`ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`:
+        https://arxiv.org/abs/1909.11942
 
     .. _`tf.keras.Model`:
         https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model
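Note: the updated docstring refers to ALBERT's two parameter-reduction techniques (factorized embedding parameterization and cross-layer parameter sharing). A minimal sketch, not part of this commit, of how those surface in the configuration; the field names follow AlbertConfig, and the printed values are assumed to be the albert-base-v2 defaults:

    from transformers import AlbertConfig

    config = AlbertConfig.from_pretrained('albert-base-v2')
    # Factorized embeddings: the vocabulary is projected to a small embedding_size
    # before being mapped up to the transformer hidden_size.
    print(config.embedding_size, config.hidden_size)            # expected 128 vs. 768
    # Cross-layer sharing: all hidden layers reuse the parameters of one group.
    print(config.num_hidden_layers, config.num_hidden_groups)   # expected 12 layers, 1 group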
@@ -695,8 +694,8 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
         import tensorflow as tf
         from transformers import AlbertTokenizer, TFAlbertForMaskedLM
 
-        tokenizer = AlbertTokenizer.from_pretrained('bert-base-uncased')
-        model = TFAlbertForMaskedLM.from_pretrained('bert-base-uncased')
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = TFAlbertForMaskedLM.from_pretrained('albert-base-v2')
         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
         outputs = model(input_ids)
         prediction_scores = outputs[0]
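For context only (not part of this diff): the prediction_scores returned above are per-position vocabulary logits, so one plausible way to inspect them is to take the argmax over the vocabulary and decode. tf.argmax and the tokenizer's decode method are standard TF 2.0 / transformers calls, and the variable names reuse the example above:

    predicted_ids = tf.argmax(prediction_scores, axis=-1)       # shape (1, sequence_length)
    print(tokenizer.decode(predicted_ids[0].numpy().tolist()))  # most likely token at each position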
@@ -721,3 +720,54 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
         outputs = (prediction_scores,) + outputs[2:]
 
         return outputs  # prediction_scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
+    the pooled output) e.g. for GLUE tasks. """,
+                      ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
+class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
+    r"""
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **logits**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import AlbertTokenizer, TFAlbertForSequenceClassification
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = TFAlbertForSequenceClassification.from_pretrained('albert-base-v2')
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        outputs = model(input_ids)
+        logits = outputs[0]
+
+    """
+    def __init__(self, config, *inputs, **kwargs):
+        super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.albert = TFAlbertModel(config, name='albert')
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(config.num_labels,
+                                                kernel_initializer=get_initializer(config.initializer_range),
+                                                name='classifier')
+
+    def call(self, inputs, **kwargs):
+        outputs = self.albert(inputs, **kwargs)
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
+        logits = self.classifier(pooled_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        return outputs  # logits, (hidden_states), (attentions)
\ No newline at end of file
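A brief usage sketch for the class this commit adds, extending the docstring example with a loss computation; passing num_labels through from_pretrained and the SparseCategoricalCrossentropy loss are illustrative assumptions, not part of the commit:

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertForSequenceClassification

    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    # num_labels is assumed to be forwarded to the underlying config
    model = TFAlbertForSequenceClassification.from_pretrained('albert-base-v2', num_labels=2)

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
    logits = model(input_ids)[0]                                                 # (1, num_labels), raw scores before SoftMax
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
        tf.constant([1]), logits)                                                # toy label for the single example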