Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3cf2020c
Commit
3cf2020c
authored
Oct 30, 2019
by
Rémi Louf
Browse files
change kwargs processing
parent
a88a0e44
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
28 deletions
+43
-28
transformers/modeling_encoder_decoder.py
transformers/modeling_encoder_decoder.py
+43
-28
No files found.
transformers/modeling_encoder_decoder.py
View file @
3cf2020c
...
...
@@ -114,23 +114,28 @@ class PreTrainedEncoderDecoder(nn.Module):
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_
encoder
=
{
argument
[
len
(
"encoder_"
):]
:
value
kwargs_
common
=
{
argument
:
value
for
argument
,
value
in
kwargs
.
items
()
if
argument
.
startswith
(
"encoder_"
)
if
not
argument
.
startswith
(
"encoder_"
)
and
not
argument
.
startswith
(
"decoder_"
)
}
kwargs_decoder
=
{
argument
[
len
(
"decoder_"
):]:
value
kwargs_decoder
=
kwargs_common
.
copy
()
kwargs_encoder
=
kwargs_common
.
copy
()
kwargs_encoder
.
update
(
{
argument
[
len
(
"encoder_"
)
:]:
value
for
argument
,
value
in
kwargs
.
items
()
if
argument
.
startswith
(
"
d
ecoder_"
)
if
argument
.
startswith
(
"e
n
coder_"
)
}
kwargs_common
=
{
argument
:
value
)
kwargs_decoder
.
update
(
{
argument
[
len
(
"decoder_"
)
:]:
value
for
argument
,
value
in
kwargs
.
items
()
if
not
(
argument
.
startswith
(
"encoder_"
)
or
argument
.
startswith
(
"decoder_"
)
)
if
argument
.
startswith
(
"decoder_"
)
}
kwargs_decoder
=
dict
(
kwargs_common
,
**
kwargs_decoder
)
kwargs_encoder
=
dict
(
kwargs_common
,
**
kwargs_encoder
)
)
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
...
...
@@ -185,35 +190,44 @@ class PreTrainedEncoderDecoder(nn.Module):
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_
encoder
=
{
argument
[
len
(
"encoder_"
):]
:
value
kwargs_
common
=
{
argument
:
value
for
argument
,
value
in
kwargs
.
items
()
if
argument
.
startswith
(
"encoder_"
)
if
not
argument
.
startswith
(
"encoder_"
)
and
not
argument
.
startswith
(
"decoder_"
)
}
kwargs_decoder
=
{
argument
[
len
(
"decoder_"
):]:
value
kwargs_decoder
=
kwargs_common
.
copy
()
kwargs_encoder
=
kwargs_common
.
copy
()
kwargs_encoder
.
update
(
{
argument
[
len
(
"encoder_"
)
:]:
value
for
argument
,
value
in
kwargs
.
items
()
if
argument
.
startswith
(
"
d
ecoder_"
)
if
argument
.
startswith
(
"e
n
coder_"
)
}
kwargs_common
=
{
argument
:
value
)
kwargs_decoder
.
update
(
{
argument
[
len
(
"decoder_"
)
:]:
value
for
argument
,
value
in
kwargs
.
items
()
if
not
(
argument
.
startswith
(
"encoder_"
)
or
argument
.
startswith
(
"decoder_"
)
)
if
argument
.
startswith
(
"decoder_"
)
}
kwargs_decoder
=
dict
(
kwargs_common
,
**
kwargs_decoder
)
kwargs_encoder
=
dict
(
kwargs_common
,
**
kwargs_encoder
)
)
# Encode if needed (training, first prediction pass)
encoder_hidden_states
=
kwargs_encoder
.
pop
(
"hidden_states"
,
None
)
if
encoder_hidden_states
is
None
:
encoder_outputs
=
self
.
encoder
(
encoder_input_ids
,
**
kwargs_encoder
)
encoder_hidden_states
=
encoder_outputs
[
0
]
# output the last layer hidden state
encoder_hidden_states
=
encoder_outputs
[
0
]
# output the last layer hidden state
else
:
encoder_outputs
=
()
# Decode
kwargs_decoder
[
"encoder_hidden_states"
]
=
encoder_hidden_states
kwargs_decoder
[
"encoder_attention_mask"
]
=
kwargs_encoder
.
get
(
"attention_mask"
,
None
)
kwargs_decoder
[
"encoder_attention_mask"
]
=
kwargs_encoder
.
get
(
"attention_mask"
,
None
)
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
,
**
kwargs_decoder
)
return
decoder_outputs
+
encoder_outputs
...
...
@@ -235,6 +249,7 @@ class Model2Model(PreTrainedEncoderDecoder):
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs):
    """Initialize through the parent class, then call ``tie_weights()``.

    Every positional and keyword argument is forwarded unchanged to the
    parent initializer.
    """
    # Explicit two-argument super() is kept exactly as in the original.
    super(Model2Model, self).__init__(*args, **kwargs)
    # Runs only after the parent initializer has returned.
    self.tie_weights()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment