transformers / Commits

Unverified commit 1a5aefc9, authored Mar 26, 2020 by Sam Shleifer, committed via GitHub on Mar 26, 2020
[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)
Parent: 39371ee4

Showing 3 changed files with 29 additions and 15 deletions.
src/transformers/modeling_bart.py   +1 -1
src/transformers/modeling_t5.py     +1 -0
src/transformers/modeling_utils.py  +27 -14
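Context for the diffs that follow: before this commit, generate() tiled input_ids for beam search first and only then ran the encoder, so every beam re-encoded the same source sequence. The change moves the encoder call ahead of the tiling and duplicates only the encoder outputs along the batch dimension. A minimal sketch of the new ordering, using a toy encoder and assumed shapes rather than the actual transformers classes:

    import torch
    from torch import nn

    # Toy stand-in for model.get_encoder(); names and shapes are illustrative
    # assumptions, not the transformers API.
    class ToyEncoder(nn.Module):
        def __init__(self, vocab_size=100, hidden_size=8):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)

        def forward(self, input_ids, attention_mask=None):
            # Return a tuple, mirroring the (hidden_states, ...) encoder outputs.
            return (self.embed(input_ids),)

    encoder = ToyEncoder()
    input_ids = torch.randint(0, 100, (2, 5))  # (batch_size=2, seq_len=5)
    num_beams = 3

    # New order: encode the original, un-expanded batch exactly once...
    encoder_outputs = encoder(input_ids)

    # ...then repeat only the encoder outputs once per beam.
    idx = torch.arange(input_ids.shape[0]).repeat_interleave(num_beams)
    encoder_outputs = (encoder_outputs[0].index_select(0, idx), *encoder_outputs[1:])
    assert encoder_outputs[0].shape == (2 * num_beams, 5, 8)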
src/transformers/modeling_bart.py

...
@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel):
     config_class = BartConfig
     base_model_prefix = "model"
     pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
+    encoder_outputs_batch_dim_idx = 1  # outputs shaped (seq_len, bs, ...)

     def _init_weights(self, module):
         std = self.config.init_std
...
@@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
             encoder_outputs, decoder_cached_states = past, None
         else:
             encoder_outputs, decoder_cached_states = past
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
...
src/transformers/modeling_t5.py

...
@@ -457,6 +457,7 @@ class T5PreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
+    encoder_outputs_batch_dim_idx = 0  # outputs shaped (bs, ...)

     @property
     def dummy_inputs(self):
...
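The new encoder_outputs_batch_dim_idx class attribute records where the batch dimension sits in each model's encoder output: per the comments above, BART's encoder output is time-major, (seq_len, bs, ...), while T5's is batch-major, (bs, ...). A small sketch of why the duplication step in generate() needs this per-model index (the tensor shapes are assumptions for illustration):

    import torch

    bart_like = torch.randn(7, 2, 16)  # (seq_len=7, bs=2, hidden=16): batch dim 1
    t5_like = torch.randn(2, 7, 16)    # (bs=2, seq_len=7, hidden=16): batch dim 0

    idx = torch.tensor([0, 0, 1, 1])   # each of 2 examples repeated for 2 beams

    # index_select has to target the model-specific batch dimension.
    assert bart_like.index_select(1, idx).shape == (7, 4, 16)
    assert t5_like.index_select(0, idx).shape == (4, 7, 16)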
src/transformers/modeling_utils.py

...
@@ -895,6 +895,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             effective_batch_size = batch_size
             effective_batch_mult = 1

+        if self.config.is_encoder_decoder:
+            if decoder_start_token_id is None:
+                decoder_start_token_id = bos_token_id
+
+            assert (
+                decoder_start_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
+            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+
+            # get encoder and store encoder outputs
+            encoder = self.get_encoder()
+            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
+
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]
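Because the moved block now executes before the "Expand input ids" step, the encoder consumes the original batch_size sequences rather than the beam-tiled batch. Rough arithmetic with assumed sizes (a sketch, not a benchmark):

    batch_size, num_beams = 2, 4

    # Old order: the encoder ran after input_ids were tiled for beam search.
    old_encoder_rows = batch_size * num_beams  # 8 sequences encoded

    # New order: the encoder runs once on the original batch.
    new_encoder_rows = batch_size  # 2 sequences encoded

    print(old_encoder_rows // new_encoder_rows)  # 4x fewer encoder forward rows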
...
@@ -911,20 +926,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

         if self.config.is_encoder_decoder:
-            if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
-
-            assert (
-                decoder_start_token_id is not None
-            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
-            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
-            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
-
-            # get encoder and store encoder outputs
-            encoder = self.get_encoder()
-            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
-
             # create empty decoder_input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
...
@@ -933,6 +934,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 device=next(self.parameters()).device,
             )
             cur_len = 1
+
+            batch_idx = self.encoder_outputs_batch_dim_idx
+            assert (
+                batch_size == encoder_outputs[0].shape[batch_idx]
+            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]}"
+            expanded_idx = (
+                torch.arange(batch_size)
+                .view(-1, 1)
+                .repeat(1, num_beams * effective_batch_mult)
+                .view(-1)
+                .to(input_ids.device)
+            )
+            encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
         else:
             encoder_outputs = None
             cur_len = input_ids.shape[-1]
...
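The arange/view/repeat/view chain above builds the row-duplication index used by index_select; a standalone sketch of what it produces, with the sizes assumed here for illustration:

    import torch

    batch_size, num_beams, effective_batch_mult = 2, 3, 1

    expanded_idx = (
        torch.arange(batch_size)
        .view(-1, 1)
        .repeat(1, num_beams * effective_batch_mult)
        .view(-1)
    )
    print(expanded_idx.tolist())  # [0, 0, 0, 1, 1, 1]: each source row once per beam

    # Applied along the batch dimension, this widens the encoder output
    # without re-running the encoder.
    hidden = torch.randn(batch_size, 5, 8)  # assumed batch-major layout
    expanded = hidden.index_select(0, expanded_idx)
    assert expanded.shape == (batch_size * num_beams, 5, 8)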