Commit 95ec1d08 authored Oct 16, 2019 by Rémi Louf

separate inputs into encoder & decoder inputs

parent e4e0ee14
Showing 1 changed file with 8 additions and 3 deletions: transformers/modeling_seq2seq.py (+8, −3).
--- a/transformers/modeling_seq2seq.py
+++ b/transformers/modeling_seq2seq.py
@@ -130,7 +130,7 @@ class PreTrainedSeq2seq(nn.Module):
 
         return model
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2seq depends on what we are performing:
 
         - During training we perform one forward pass through both the encoder
@@ -142,6 +142,11 @@ class PreTrainedSeq2seq(nn.Module):
         Therefore, we skip the forward pass on the encoder if an argument named
         `encoder_hidden_states` is passed to this function.
 
+        Params:
+            encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
+                Indices of encoder input sequence tokens in the vocabulary.
+            decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
+                Indices of decoder input sequence tokens in the vocabulary.
         """
         # Separate the encoder- and decoder-specific kwargs. A kwarg is
         # decoder-specific if the key starts with `decoder_`
@@ -154,14 +159,14 @@ class PreTrainedSeq2seq(nn.Module):
 
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
         if encoder_hidden_states is None:
-            encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
+            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()
 
         # Decode
         kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
-        decoder_outputs = self.decoder(**kwargs_decoder)
+        decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
 
         return decoder_outputs + encoder_outputs
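
The kwargs dispatch that the "decoder-specific" comment refers to sits in context this diff elides. A minimal sketch of what the comment describes, assuming the `decoder_` prefix is stripped before the argument reaches the decoder (only the names `kwargs_encoder` and `kwargs_decoder` come from the diff; the example keys are illustrative):

    # Split **kwargs into encoder- and decoder-specific dicts. A kwarg is
    # decoder-specific if its key starts with `decoder_`; this sketch assumes
    # the prefix is stripped before the decoder sees the argument.
    kwargs = {
        "attention_mask": None,          # encoder-specific: no prefix
        "decoder_attention_mask": None,  # decoder-specific: `decoder_` prefix
    }

    kwargs_encoder = {
        key: value
        for key, value in kwargs.items()
        if not key.startswith("decoder_")
    }
    kwargs_decoder = {
        key[len("decoder_"):]: value
        for key, value in kwargs.items()
        if key.startswith("decoder_")
    }

    assert kwargs_encoder == {"attention_mask": None}
    assert kwargs_decoder == {"attention_mask": None}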
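
Taken together, the new signature makes the two-pass pattern explicit: the encoder runs once, and later calls can skip it by passing the cached `encoder_hidden_states`. A runnable sketch of that control flow, with toy modules standing in for real pretrained models (the free function mirrors `forward` from the diff; everything else is an assumption for illustration):

    import torch
    from torch import nn


    class ToyCoder(nn.Module):
        # Stand-in that, like the real models, returns a tuple whose first
        # element is the hidden states (the diff indexes it with `[0]`).
        def __init__(self, vocab_size=100, hidden_size=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)

        def forward(self, input_ids, **kwargs):
            return (self.embed(input_ids),)


    def seq2seq_forward(encoder, decoder, encoder_input_ids, decoder_input_ids, **kwargs):
        # Same control flow as PreTrainedSeq2seq.forward in the diff above.
        kwargs_encoder = {k: v for k, v in kwargs.items()
                          if not k.startswith("decoder_")}
        kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items()
                          if k.startswith("decoder_")}

        # Encode if needed (training, first prediction pass).
        encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
        if encoder_hidden_states is None:
            encoder_outputs = encoder(encoder_input_ids, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        # Decode, conditioning on the encoder's hidden states.
        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        decoder_outputs = decoder(decoder_input_ids, **kwargs_decoder)
        return decoder_outputs + encoder_outputs


    encoder, decoder = ToyCoder(), ToyCoder()
    src = torch.tensor([[1, 2, 3, 4]])
    tgt = torch.tensor([[5, 6]])

    # First pass: the encoder runs.
    outputs = seq2seq_forward(encoder, decoder, src, tgt)

    # Later passes: hand back the cached states and the encoder is skipped.
    cached = encoder(src)[0]
    outputs = seq2seq_forward(encoder, decoder, src, tgt, encoder_hidden_states=cached)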