Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3cf2020c
"tests/optimization_test.py" did not exist on "158e82e061c02fc2f1613adb7ac1d1cb6adae71c"
Commit
3cf2020c
authored
Oct 30, 2019
by
Rémi Louf
Browse files
change kwargs processing
parent
a88a0e44
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
28 deletions
+43
-28
transformers/modeling_encoder_decoder.py
transformers/modeling_encoder_decoder.py
+43
-28
No files found.
transformers/modeling_encoder_decoder.py
View file @
3cf2020c
...
@@ -114,23 +114,28 @@ class PreTrainedEncoderDecoder(nn.Module):
...
@@ -114,23 +114,28 @@ class PreTrainedEncoderDecoder(nn.Module):
# Split the kwargs into three groups: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_common = {
    argument: value
    for argument, value in kwargs.items()
    if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
# Start each sub-model's kwargs from the shared ones, then overlay the
# prefixed arguments (with their prefix stripped) so specific values win.
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
kwargs_encoder.update(
    {
        argument[len("encoder_"):]: value
        for argument, value in kwargs.items()
        if argument.startswith("encoder_")
    }
)
kwargs_decoder.update(
    {
        argument[len("decoder_"):]: value
        for argument, value in kwargs.items()
        if argument.startswith("decoder_")
    }
)

# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# Load and initialize the encoder and decoder
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# The distinction between encoder and decoder at the model level is made
...
@@ -185,35 +190,44 @@ class PreTrainedEncoderDecoder(nn.Module):
...
@@ -185,35 +190,44 @@ class PreTrainedEncoderDecoder(nn.Module):
# Separate the kwargs into encoder-specific (prefixed by `encoder_`),
# decoder-specific (prefixed by `decoder_`) and those that apply to the
# model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_common = {
    argument: value
    for argument, value in kwargs.items()
    if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
# Seed both sub-model kwarg dicts from the common arguments, then merge in
# the de-prefixed specific arguments so that they take precedence.
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
kwargs_encoder.update(
    {
        argument[len("encoder_"):]: value
        for argument, value in kwargs.items()
        if argument.startswith("encoder_")
    }
)
kwargs_decoder.update(
    {
        argument[len("decoder_"):]: value
        for argument, value in kwargs.items()
        if argument.startswith("decoder_")
    }
)
# Encode if needed (training, first prediction pass)
# Encode if needed (training, first prediction pass)
encoder_hidden_states
=
kwargs_encoder
.
pop
(
"hidden_states"
,
None
)
encoder_hidden_states
=
kwargs_encoder
.
pop
(
"hidden_states"
,
None
)
if
encoder_hidden_states
is
None
:
if
encoder_hidden_states
is
None
:
encoder_outputs
=
self
.
encoder
(
encoder_input_ids
,
**
kwargs_encoder
)
encoder_outputs
=
self
.
encoder
(
encoder_input_ids
,
**
kwargs_encoder
)
encoder_hidden_states
=
encoder_outputs
[
0
]
# output the last layer hidden state
encoder_hidden_states
=
encoder_outputs
[
0
]
# output the last layer hidden state
else
:
else
:
encoder_outputs
=
()
encoder_outputs
=
()
# Decode
# Decode
kwargs_decoder
[
"encoder_hidden_states"
]
=
encoder_hidden_states
kwargs_decoder
[
"encoder_hidden_states"
]
=
encoder_hidden_states
kwargs_decoder
[
"encoder_attention_mask"
]
=
kwargs_encoder
.
get
(
"attention_mask"
,
None
)
kwargs_decoder
[
"encoder_attention_mask"
]
=
kwargs_encoder
.
get
(
"attention_mask"
,
None
)
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
,
**
kwargs_decoder
)
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
,
**
kwargs_decoder
)
return
decoder_outputs
+
encoder_outputs
return
decoder_outputs
+
encoder_outputs
...
@@ -235,6 +249,7 @@ class Model2Model(PreTrainedEncoderDecoder):
...
@@ -235,6 +249,7 @@ class Model2Model(PreTrainedEncoderDecoder):
decoder = BertForMaskedLM(config)
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
"""
def __init__(self, *args, **kwargs):
    """Initialize the encoder-decoder pair, then tie their embedding weights.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor; weight tying is applied immediately afterwards.
    """
    super(Model2Model, self).__init__(*args, **kwargs)
    self.tie_weights()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment