chenpangpang/transformers
Commit 4374eaea: ALBERT for SQuAD
Authored Nov 06, 2019 by Lysandre; committed by Lysandre Debut, Nov 26, 2019
Parent: 70d99980
Showing 3 changed files, with 100 additions and 6 deletions (+100 / -6):

  examples/run_squad.py            +7   -5
  transformers/__init__.py         +1   -0
  transformers/modeling_albert.py  +92  -1
examples/run_squad.py  (+7 -5)

@@ -43,7 +43,8 @@ from transformers import (WEIGHTS_NAME, BertConfig,
                                   XLMTokenizer, XLNetConfig,
                                   XLNetForQuestionAnswering,
                                   XLNetTokenizer,
-                                  DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
+                                  DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
+                                  AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)

 from transformers import AdamW, get_linear_schedule_with_warmup

@@ -65,7 +66,8 @@ MODEL_CLASSES = {
     'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
     'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
     'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
-    'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
+    'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
+    'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
 }

 def set_seed(args):

@@ -128,7 +130,7 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info("  Total optimization steps = %d", t_total)

-    global_step = 0
+    global_step = 1
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

@@ -537,7 +539,7 @@ def main():
         torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
+        model = model_class.from_pretrained(args.output_dir, force_download=True)
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)

@@ -555,7 +557,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)
+            model = model_class.from_pretrained(checkpoint, force_download=True)
             model.to(args.device)

             # Evaluate
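With the 'albert' key registered in MODEL_CLASSES, run_squad.py can be launched with --model_type albert and resolve the ALBERT config, model, and tokenizer classes the same way it does for BERT or XLNet. A minimal sketch of that lookup; the checkpoint name albert-base-v2 and the flattened argument handling are illustrative, not taken from this diff:

# Sketch of how the registry entry added above is consumed (simplified from the
# script's argument handling; 'albert-base-v2' is an assumed checkpoint name).
from transformers import AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer

MODEL_CLASSES = {
    'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
}

model_type, model_name_or_path = 'albert', 'albert-base-v2'
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]

config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
model = model_class.from_pretrained(model_name_or_path, config=config)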
transformers/__init__.py  (+1 -0)

@@ -108,6 +108,7 @@ if is_torch_available():
     from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
     from .modeling_albert import (AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification,
+                                  AlbertForQuestionAnswering,
                                   load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)

     # Optimization
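This re-export is what lets the example script import AlbertForQuestionAnswering from the top-level transformers package rather than reaching into the private module. A quick check, assuming the package layout of this commit:

# The top-level re-export and the module-level class are the same object
# (module path 'transformers.modeling_albert' is as of this commit).
import transformers
from transformers.modeling_albert import AlbertForQuestionAnswering

assert transformers.AlbertForQuestionAnswering is AlbertForQuestionAnswering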
transformers/modeling_albert.py  (+92 -1)

@@ -586,4 +586,95 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             outputs = (loss,) + outputs

-        return outputs  # (loss), logits, (hidden_states), (attentions)
\ No newline at end of file
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
+class AlbertForQuestionAnswering(AlbertPreTrainedModel):
+    r"""
+        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-start scores (before SoftMax).
+        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-end scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = AlbertForQuestionAnswering.from_pretrained('albert-base-v2')
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
+        input_ids = tokenizer.encode(input_text)
+        token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
+        start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
+        # a nice puppet
+    """
+    def __init__(self, config):
+        super(AlbertForQuestionAnswering, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.albert = AlbertModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, start_positions=None, end_positions=None):
+
+        outputs = self.albert(input_ids,
+                              attention_mask=attention_mask,
+                              token_type_ids=token_type_ids,
+                              position_ids=position_ids,
+                              head_mask=head_mask)
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
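The docstring above shows only the inference path. When start_positions and end_positions are supplied, as run_squad.py does during training, forward() prepends the averaged start/end cross-entropy loss to the returned tuple. A minimal sketch of that supervised path; the gold span indices below are hypothetical and chosen only to exercise the loss:

# Sketch of the training-time path of AlbertForQuestionAnswering.forward().
import torch
from transformers import AlbertForQuestionAnswering, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertForQuestionAnswering.from_pretrained('albert-base-v2')

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_ids = torch.tensor([tokenizer.encode(question, text)])  # (1, seq_len) with special tokens

# Hypothetical gold span indices; out-of-range values would be clamped to ignored_index.
start_positions = torch.tensor([9])
end_positions = torch.tensor([11])

loss, start_logits, end_logits = model(input_ids,
                                       start_positions=start_positions,
                                       end_positions=end_positions)[:3]
loss.backward()  # loss = (start cross-entropy + end cross-entropy) / 2, as computed in forward()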