chenpangpang / transformers · Commits

Commit 4c3ac4a7, authored Oct 18, 2019 by Rémi Louf
"here's one big commit"
Parent: 932543f7
Showing 7 changed files with 951 additions and 47 deletions (+951, −47)
examples/README.md (+3, −2)
examples/run_summarization_finetuning.py (+620, −0)
examples/run_summarization_finetuning_test.py (+14, −14)
transformers/__init__.py (+1, −1)
transformers/modeling_beam_search.py (+240, −0)
transformers/modeling_bert.py (+12, −8)
transformers/modeling_seq2seq.py (+61, −22)
examples/README.md

@@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference
 ## Seq2seq model fine-tuning
-Based on the script [`run_seq2seq_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_seq2seq_finetuning.py).
+Based on the script [`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
 Before running this script you should download **both** CNN and Daily Mail
 datasets from [Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the
@@ -412,7 +413,7 @@ archive.
 ```bash
 export DATA_PATH=/path/to/dataset/
-python run_seq2seq_finetuning.py \
+python run_summarization_finetuning.py \
     --output_dir=output \
     --model_type=bert2bert \
     --model_name_or_path=bert2bert \
 ...
examples/run_seq2seq_finetuning.py → examples/run_summarization_finetuning.py (renamed; this diff is collapsed and not shown here)
examples/run_seq2seq_finetuning_test.py → examples/run_summarization_finetuning_test.py (renamed)
@@ -14,7 +14,7 @@
 # limitations under the License.
 import unittest
-from run_seq2seq_finetuning import _fit_to_block_size, process_story
+from run_summarization_finetuning import _fit_to_block_size, process_story
 class DataLoaderTest(unittest.TestCase):
@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
         raw_story = """It was the year of Our Lord one thousand seven hundred and
         seventy-five.\n\nSpiritual revelations were conceded to England at that
         favoured period, as at this."""
-        _, summary = process_story(raw_story)
-        self.assertEqual(summary, [])
+        with self.assertRaises(IndexError):
+            process_story(raw_story)
     def test_process_empty_story(self):
         """ An empty story should also raise and exception.
         """
         raw_story = ""
-        story, summary = process_story(raw_story)
-        self.assertEqual(story, [])
-        self.assertEqual(summary, [])
+        with self.assertRaises(IndexError):
+            process_story(raw_story)
     def test_story_with_missing_period(self):
         raw_story = (
@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
             "seventy-five\n\nSpiritual revelations were conceded to England "
             "at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
         )
-        story, summary = process_story(raw_story)
-        expected_story = (
-            "It was the year of Our Lord one thousand seven hundred and "
-            "seventy-five. Spiritual revelations were conceded to England at that "
-            "favoured period, as at this."
-        )
-        self.assertEqual(expected_story, story)
-        expected_summary = "It was the best of times."
-        self.assertEqual(expected_summary, summary)
+        story_lines, summary_lines = process_story(raw_story)
+        expected_story_lines = [
+            "It was the year of Our Lord one thousand seven hundred and seventy-five.",
+            "Spiritual revelations were conceded to England at that favoured period, as at this.",
+        ]
+        self.assertEqual(expected_story_lines, story_lines)
+        expected_summary_lines = ["It was the best of times."]
+        self.assertEqual(expected_summary_lines, summary_lines)
 if __name__ == "__main__":
 ...
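The updated assertions describe the expected behaviour of `process_story` (whose implementation lives in the collapsed `run_summarization_finetuning.py` diff above): split a raw CNN/DailyMail story on `@highlight` markers, append missing final periods, and raise `IndexError` on empty or highlight-free input. A minimal sketch consistent with these tests, written for illustration only and not taken from the commit, could look like this:

```python
def process_story(raw_story):
    """Hypothetical sketch of what the tests above exercise: split a raw
    CNN/DailyMail story on "@highlight" markers into (story_lines,
    summary_lines), appending a final period to lines that lack one."""
    lines = [line.strip() for line in raw_story.split("\n") if line.strip()]

    # Everything before the first "@highlight" belongs to the story; every
    # line that follows an "@highlight" marker is one summary line.
    story_lines, summary_lines = [], []
    highlight_seen = False
    for line in lines:
        if line == "@highlight":
            highlight_seen = True
            continue
        target = summary_lines if highlight_seen else story_lines
        target.append(line if line.endswith(".") else line + ".")

    # Mirror the behaviour asserted in the tests: an empty story or a story
    # without any "@highlight" marker raises IndexError.
    if not summary_lines:
        raise IndexError("No highlights found in story.")
    return story_lines, summary_lines
```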
transformers/__init__.py

@@ -87,7 +87,7 @@ if is_torch_available():
     from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                       DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
                                       DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
-    from .modeling_seq2seq import Model2Model
+    from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
     # Optimization
     from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
 ...
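With the export added above, the new wrapper class is importable from the package root alongside `Model2Model` (assuming the library is installed as `transformers`):

```python
# Both seq2seq entry points are now re-exported at the package root.
from transformers import Model2Model, PreTrainedSeq2seq
```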
transformers/modeling_beam_search.py (new file, 0 → 100644)

# coding=utf-8
# Copyright (c) 2019 Yang Liu
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
A general wrapper around models with LM heads to generate sequences
using beam search.
"""
import torch
from torch import nn


class ModelWithBeamSearch(nn.Module):
    def __init__(
        self,
        model,
        beam_size,
        start_token_id,
        end_token_id,
        pad_token_id,
        min_length,
        max_length,
        alpha,
        block_trigram=True,
    ):
        """
        Attributes:
            mask_word_id: token id that corresponds to the mask
        """
        super(ModelWithBeamSearch, self).__init__()
        self.model = model
        self.beam_size = beam_size
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        self.pad_token_id = pad_token_id
        self.min_length = min_length
        self.max_length = max_length
        self.alpha = alpha
        self.block_trigram = block_trigram

    def forward(self, input_ids, **kwargs):
        # Separate the encoder- and decoder- specific kwargs. A kwarg is
        # decoder-specific it the key starts with `decoder_`
        kwargs_encoder = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("decoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        batch_size, _ = input_ids.size(0)

        # Variables that keep track of the status of the search
        hypotheses = [[] for _ in range(batch_size)]
        batch_offset = torch.arange(batch_size, dtype=torch.long)
        beam_offset = torch.arange(
            0,
            batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
        )
        growing_beam = torch.full(
            (batch_size * self.beam_size, 1),
            self.start_token_id,
            dtype=torch.long,
        )
        topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            dtype=torch.float,
        ).repeat(batch_size)

        # Forward pass on the encoder
        encoder_outputs = self.encoder(input_ids, kwargs_encoder)
        kwargs_decoder["encoder_hidden_states"] = tile(encoder_outputs, self.beam_size, dim=0)

        results = {}
        results["predictions"] = [[] for _ in batch_size]
        results["scores"] = [[] for _ in batch_size]

        for step in range(self.max_length):
            decoder_input = growing_beam[:, -1]
            outputs = self.decoder(decoder_input, kwargs_decoder)
            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
            vocab_size = log_probabilities.size(-1)

            # The batch size changes as some beams finish so we define:
            _B = log_probabilities.size(0) // self.beam_size

            # Multiply each beam probability with the probability of the
            # next token (conditioned on the words in the beam).
            log_probabilities += topk_log_probabilities.view(-1, 1)

            # if the beam has not attained the minimum required length we
            # make the end token arbitrarily unlikely.
            if step < self.min_length:
                log_probabilities[self.end_token_id] = -1e20

            # Remove repeating tri-grams
            if (self.args.block_trigram):
                if (step + 1 > 3):
                    for i in range(_B * self.beam_size):
                        tokens = [t for t in growing_beam[i]]
                        trigrams = [
                            (tokens[i - 1], tokens[i], tokens[i + 1])
                            for i in range(1, len(words) - 1)
                        ]
                        last_trigram = tuple(trigrams[-1])
                        if last_trigram in trigrams[:-1]:
                            log_probabilities[i] = -1e20

            # Find the `beam_size` (previous_beam + token) combinations with
            # the highest score
            topk_log_probabilities, topk_ids = log_probabilities.topk(
                log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
            )

            # Apply the length penalty. The +1 accounts for the [EOS] token
            # that will be added if the beam ends.
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
            topk_scores = topk_log_probabilities / length_penalty

            # Retrieve the corresponding respective beam and token id
            # topk_token_ids[i] will be added to topk_beam_ids[i]
            topk_beam_ids = topk_ids.div(vocab_size)
            topk_token_ids = topk_ids.fmod(vocab_size)

            # Retrieve the row index of the surviving beams in the original
            # view of the log_probabilities tensor
            surviving_beams_rows = (topk_beam_ids + beam_offset[:_B].view(-1, 1)).view(-1)

            # Append the last predictions
            growing_beam = torch.cat(
                [
                    growing_beam.index_select(0, surviving_beams_rows),
                    topk_token_ids.view(-1, 1),
                ],
                1,
            )

            # Check if any of the beam searches has ended during this
            # growth step. Also if top beam (most probable) has ended
            # for one element of the batch.
            is_finished = topk_token_ids.eq(self.end_token_id)
            if step + 1 == self.max_length:
                is_finished.fill_(1)
            is_top_beam_finished = is_finished[:, 0].eq(1)

            # Save the finished searches
            if is_finished.any():
                predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
                for i in range(is_finished.size(0)):
                    if is_top_beam_finished[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)

                    # Store finished hypotheses for this batch.
                    b = batch_offset[i]
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

                    # If the batch reached the end, save the best hypotheses
                    # in terms of length-penalized score.
                    if is_top_beam_finished[i]:
                        best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                        best_score, best_prediction = best_hyp[0]
                        results["scores"][b].append(best_score)
                        results["predictions"][b].append(best_prediction)

                non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
                if len(non_finished) == 0:
                    break

                # Remove finished batches for the next step.
                topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                growing_beam = predictions.index_select(0, non_finished).view(
                    -1, growing_beam.size(-1)
                )

                # Re-order the state for the next pass
                surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
                kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
                    "encoder_hidden_states"
                ].index_select(0, surviving_beams_rows)

        return results


def tile(x, count, dim=0):
    """
    Tiles `x` along dimension `dim` `count` times.

    Example:
        >> ex = torch.tensor([1,2],[3,4])
        >> tile(ex, 2, 0)
        torch.Tensor([[1,2],[1,2],[3,4],[3,4]])
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
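As a quick sanity check on two pieces of the new module (not part of the commit), the snippet below evaluates the GNMT-style length penalty used above for a few lengths and shows that `tile(x, count, dim=0)` produces the same beam-major row layout as `torch.repeat_interleave`. The value of `alpha` is arbitrary, and the snippet assumes a PyTorch version recent enough to provide `repeat_interleave` (≥ 1.1).

```python
import torch

# Length penalty from the loop above: scores are divided by ((5 + length) / 6) ** alpha,
# so longer hypotheses are not penalised simply for summing more log-probabilities.
alpha = 0.95  # arbitrary illustrative value
for length in (1, 5, 10):
    print(length, ((5.0 + length) / 6.0) ** alpha)

# tile(x, beam_size, dim=0) repeats each batch row beam_size times, keeping the copies
# of one example adjacent; torch.repeat_interleave yields the same layout.
x = torch.tensor([[1, 2], [3, 4]])
print(x.repeat_interleave(2, dim=0))  # tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
```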
transformers/modeling_bert.py

@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
                 batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length)
+                seq_ids = torch.arange(seq_length, device=input_ids.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        # If a 2D encoder attention mask is provided for the cross-attention
+        # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = encoder_attention_mask[:, None, None, :]
+            encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                                    self.bert.embeddings.word_embeddings)
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
+                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):
         # 1. If a tensor that contains the indices of masked labels is provided,
         # the cross-entropy is the MLM cross-entropy that measures the likelihood
         # of predictions for masked words.
-        # 2. If encoder hidden states are provided we are in a causal situation where we
+        # 2. If `lm_label` is provided we are in a causal scenario where we
         # try to predict the next word for each input in the encoder.
-        if masked_lm_labels is not None and lm_labels is not None:
-            raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):
         if lm_labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :]
-            lm_labels = lm_labels[:, 1:, :]
+            lm_labels = lm_labels[:, 1:]
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
+            seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
             outputs = (seq2seq_loss,) + outputs
         return outputs  # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
 ...
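To see what the mask handling above computes, here is a small self-contained illustration (not part of the commit) of how a 2D padding mask and the causal mask are broadcast to 4D and turned into an additive bias before the softmax; the explicit `.float()` cast is added here only so the toy example runs on recent PyTorch versions.

```python
import torch

batch_size, seq_length = 1, 4
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float)  # last position is padding

# Causal mask as in BertModel: position i may attend to position j only if j <= i.
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

# Combine with the padding mask and convert to an additive bias:
# allowed positions stay 0.0, blocked ones become -10000.0.
extended = causal_mask[:, None, :, :].float() * attention_mask[:, None, None, :]
extended = (1.0 - extended) * -10000.0
print(extended[0, 0])  # lower triangle (minus the padded column) is 0, the rest -10000
```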
transformers/modeling_seq2seq.py

@@ -17,13 +17,12 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
+import os
 import torch
 from torch import nn
-from .file_utils import add_start_docstrings
 from .modeling_auto import AutoModel, AutoModelWithLMHead
-from .modeling_utils import PreTrainedModel, SequenceSummary
 logger = logging.getLogger(__name__)
@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):
         self.decoder = decoder
     @classmethod
     def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
         r""" Instantiates an encoder and a decoder from one or two base classes
         of the library from pre-trained model checkpoints.
@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):
         # Separate the encoder- and decoder- specific kwargs. A kwarg is
         # decoder-specific it the key starts with `decoder_`
-        kwargs_decoder = {}
-        kwargs_encoder = kwargs
-        for key in kwargs_encoder.keys():
-            if key.startswith("decoder_"):
-                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
+        kwargs_encoder = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not argument.startswith("decoder_")
+        }
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("decoder_")
+        }
         # Load and initialize the encoder and decoder
         # The distinction between encoder and decoder at the model level is made
         # by the value of the flag `is_decoder` that we need to set correctly.
-        encoder = kwargs.pop("encoder_model", None)
+        encoder = kwargs_encoder.pop("encoder_model", None)
         if encoder is None:
             kwargs_encoder["is_decoder"] = False
             encoder = AutoModel.from_pretrained(
                 encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
             )
-        decoder = kwargs.pop("decoder_model", None)
+        decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             kwargs_decoder["is_decoder"] = True
             decoder = AutoModelWithLMHead.from_pretrained(
@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):
         return model
+    def save_pretrained(self, save_directory):
+        """ Save a Seq2Seq model and its configuration file in a format
+        such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
+        self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
+        self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
     def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2eq depends what we are performing:
@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):
         """
         # Separate the encoder- and decoder- specific kwargs. A kwarg is
         # decoder-specific it the key starts with `decoder_`
-        kwargs_decoder = {}
-        kwargs_encoder = kwargs
-        for key in kwargs_encoder.keys():
-            if key.startswith("decoder_"):
-                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
+        kwargs_encoder = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not argument.startswith("decoder_")
+        }
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("decoder_")
+        }
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0][-1]  # output of the encoder *stack*
         else:
             encoder_outputs = ()
         # Decode
-        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
         decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
         return decoder_outputs + encoder_outputs
@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        model = super(Model2Model, cls).from_pretrained(
-            encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
-            decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
-            **kwargs
-        )
+        if (
+            "bert" not in pretrained_model_name_or_path
+            or "roberta" in pretrained_model_name_or_path
+            or "distilbert" in pretrained_model_name_or_path
+        ):
+            raise ValueError("Only the Bert model is currently supported.")
+        model = super(Model2Model, cls).from_pretrained(
+            encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
+            decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
+            **kwargs
+        )
+        # Some architectures require for the decoder to be initialized randomly
+        # before fine-tuning.
+        if kwargs.get("decoder_initialize_randomly", False):
+            model.decoder.init_weights()
         return model
 ...
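The encoder/decoder keyword split that now appears in three places (the beam-search wrapper, `from_pretrained`, and `forward`) is easy to check in isolation. The snippet below is an illustration with made-up values, followed by a hypothetical end-to-end use of the classes touched by this commit.

```python
kwargs = {
    "attention_mask": "enc_mask",
    "decoder_attention_mask": "dec_mask",
    "decoder_lm_labels": "labels",
}

# Keys without the `decoder_` prefix go to the encoder; prefixed keys are
# stripped of the prefix and routed to the decoder.
kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
print(kwargs_encoder)  # {'attention_mask': 'enc_mask'}
print(kwargs_decoder)  # {'attention_mask': 'dec_mask', 'lm_labels': 'labels'}

# Hypothetical usage of the classes above (requires downloadable pretrained weights):
# model = Model2Model.from_pretrained("bert-base-uncased")
# model.save_pretrained("./bert2bert")  # writes encoder/ and decoder/ sub-directories
```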