Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6637a77f
"docs/vscode:/vscode.git/clone" did not exist on "d22894dfd40d5c858e8398e2783545103d191b47"
Commit
6637a77f
authored
Nov 01, 2019
by
Lysandre
Committed by
Lysandre Debut
Nov 26, 2019
Browse files
AlbertForSequenceClassification
parent
0d07a23c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
77 additions
and
3 deletions
+77
-3
transformers/__init__.py
transformers/__init__.py
+2
-1
transformers/modeling_albert.py
transformers/modeling_albert.py
+75
-2
No files found.
transformers/__init__.py
View file @
6637a77f
...
...
@@ -107,7 +107,8 @@ if is_torch_available():
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_encoder_decoder
import
PreTrainedEncoderDecoder
,
Model2Model
from
.modeling_albert
import
(
AlbertModel
,
AlbertForMaskedLM
,
load_tf_weights_in_albert
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_albert
import
(
AlbertModel
,
AlbertForMaskedLM
,
AlbertForSequenceClassification
,
load_tf_weights_in_albert
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# Optimization
from
.optimization
import
(
AdamW
,
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
...
...
transformers/modeling_albert.py
View file @
6637a77f
...
...
@@ -20,10 +20,10 @@ import math
import
logging
import
torch
import
torch.nn
as
nn
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.configuration_albert
import
AlbertConfig
from
transformers.modeling_bert
import
BertEmbeddings
,
BertPreTrainedModel
,
BertModel
,
BertSelfAttention
,
prune_linear_layer
,
ACT2FN
from
transformers.modeling_bert
import
BertEmbeddings
,
BertSelfAttention
,
prune_linear_layer
,
ACT2FN
from
.file_utils
import
add_start_docstrings
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -510,3 +510,76 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
outputs
=
(
masked_lm_loss
,)
+
outputs
return
outputs
@
add_start_docstrings
(
"""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """
,
ALBERT_START_DOCSTRING
,
ALBERT_INPUTS_DOCSTRING
)
class
AlbertForSequenceClassification
(
AlbertPreTrainedModel
):
r
"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = AlbertTokenizer.from_pretrained('albert-base')
model = AlbertForSequenceClassification.from_pretrained('albert-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
def
__init__
(
self
,
config
):
super
(
AlbertForSequenceClassification
,
self
).
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
albert
=
AlbertModel
(
config
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
self
.
config
.
num_labels
)
self
.
init_weights
()
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
labels
=
None
):
outputs
=
self
.
albert
(
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
outputs
=
(
logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
# We are doing regression
loss_fct
=
MSELoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
),
labels
.
view
(
-
1
))
else
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), logits, (hidden_states), (attentions)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment