Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
390c1285
Unverified
Commit
390c1285
authored
Apr 02, 2020
by
Patrick von Platen
Committed by
GitHub
Apr 02, 2020
Browse files
[Encoder-Decoder] Force models outputs to always have batch_size as their first dim (#3536)
* solve conflicts * improve comments
parent
ab5d06a0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
12 deletions
+20
-12
src/transformers/modeling_bart.py
src/transformers/modeling_bart.py
+12
-6
src/transformers/modeling_t5.py
src/transformers/modeling_t5.py
+0
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+8
-5
No files found.
src/transformers/modeling_bart.py
View file @
390c1285
...
@@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel):
...
@@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel):
config_class
=
BartConfig
config_class
=
BartConfig
base_model_prefix
=
"model"
base_model_prefix
=
"model"
pretrained_model_archive_map
=
BART_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
BART_PRETRAINED_MODEL_ARCHIVE_MAP
encoder_outputs_batch_dim_idx
=
1
# outputs shaped (seq_len, bs, ...)
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
init_std
std
=
self
.
config
.
init_std
...
@@ -294,7 +293,10 @@ class BartEncoder(nn.Module):
...
@@ -294,7 +293,10 @@ class BartEncoder(nn.Module):
if
self
.
output_hidden_states
:
if
self
.
output_hidden_states
:
encoder_states
.
append
(
x
)
encoder_states
.
append
(
x
)
# T x B x C -> B x T x C
encoder_states
=
[
hidden_state
.
transpose
(
0
,
1
)
for
hidden_state
in
encoder_states
]
encoder_states
=
[
hidden_state
.
transpose
(
0
,
1
)
for
hidden_state
in
encoder_states
]
x
=
x
.
transpose
(
0
,
1
)
return
x
,
encoder_states
,
all_attentions
return
x
,
encoder_states
,
all_attentions
...
@@ -448,7 +450,11 @@ class BartDecoder(nn.Module):
...
@@ -448,7 +450,11 @@ class BartDecoder(nn.Module):
x
=
self
.
layernorm_embedding
(
x
)
x
=
self
.
layernorm_embedding
(
x
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
x
.
transpose
(
0
,
1
)
# (seq_len, BS, model_dim)
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
x
=
x
.
transpose
(
0
,
1
)
encoder_hidden_states
=
encoder_hidden_states
.
transpose
(
0
,
1
)
# decoder layers
# decoder layers
all_hidden_states
=
()
all_hidden_states
=
()
all_self_attns
=
()
all_self_attns
=
()
...
@@ -477,9 +483,10 @@ class BartDecoder(nn.Module):
...
@@ -477,9 +483,10 @@ class BartDecoder(nn.Module):
if
self
.
output_attentions
:
if
self
.
output_attentions
:
all_self_attns
+=
(
layer_self_attn
,)
all_self_attns
+=
(
layer_self_attn
,)
# Convert
shapes from
(seq_len, BS, model_dim)
to
(BS, seq_len, model_dim)
# Convert
to standart output format:
(seq_len, BS, model_dim)
->
(BS, seq_len, model_dim)
all_hidden_states
=
[
hidden_state
.
transpose
(
0
,
1
)
for
hidden_state
in
all_hidden_states
]
all_hidden_states
=
[
hidden_state
.
transpose
(
0
,
1
)
for
hidden_state
in
all_hidden_states
]
x
=
x
.
transpose
(
0
,
1
)
x
=
x
.
transpose
(
0
,
1
)
encoder_hidden_states
=
encoder_hidden_states
.
transpose
(
0
,
1
)
if
self
.
output_past
:
if
self
.
output_past
:
next_cache
=
((
encoder_hidden_states
,
encoder_padding_mask
),
next_decoder_cache
)
next_cache
=
((
encoder_hidden_states
,
encoder_padding_mask
),
next_decoder_cache
)
...
@@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
...
@@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
layer_past_new
=
{
layer_past_new
=
{
attn_key
:
_reorder_buffer
(
attn_cache
,
beam_idx
)
for
attn_key
,
attn_cache
in
layer_past
.
items
()
attn_key
:
_reorder_buffer
(
attn_cache
,
beam_idx
)
for
attn_key
,
attn_cache
in
layer_past
.
items
()
}
}
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
reordered_past
.
append
(
layer_past_new
)
reordered_past
.
append
(
layer_past_new
)
new_enc_out
=
enc_out
if
enc_out
is
None
else
enc_out
.
index_select
(
1
,
beam_idx
)
new_enc_out
=
enc_out
if
enc_out
is
None
else
enc_out
.
index_select
(
0
,
beam_idx
)
new_enc_mask
=
enc_mask
if
enc_mask
is
None
else
enc_mask
.
index_select
(
0
,
beam_idx
)
new_enc_mask
=
enc_mask
if
enc_mask
is
None
else
enc_mask
.
index_select
(
0
,
beam_idx
)
past
=
((
new_enc_out
,
new_enc_mask
),
reordered_past
)
past
=
((
new_enc_out
,
new_enc_mask
),
reordered_past
)
...
...
src/transformers/modeling_t5.py
View file @
390c1285
...
@@ -457,7 +457,6 @@ class T5PreTrainedModel(PreTrainedModel):
...
@@ -457,7 +457,6 @@ class T5PreTrainedModel(PreTrainedModel):
pretrained_model_archive_map
=
T5_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
T5_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_t5
load_tf_weights
=
load_tf_weights_in_t5
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
encoder_outputs_batch_dim_idx
=
0
# outputs shaped (bs, ...)
@
property
@
property
def
dummy_inputs
(
self
):
def
dummy_inputs
(
self
):
...
...
src/transformers/modeling_utils.py
View file @
390c1285
...
@@ -948,18 +948,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -948,18 +948,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device
=
next
(
self
.
parameters
()).
device
,
device
=
next
(
self
.
parameters
()).
device
,
)
)
cur_len
=
1
cur_len
=
1
batch_idx
=
self
.
encoder_outputs_batch_dim_idx
assert
(
assert
(
batch_size
==
encoder_outputs
[
0
].
shape
[
batch_idx
]
batch_size
==
encoder_outputs
[
0
].
shape
[
0
]
),
f
"expected encoder_outputs[0] to have 1st dimension bs=
{
batch_size
}
, got
{
encoder_outputs
[
0
].
shape
[
1
]
}
"
),
f
"expected encoder_outputs[0] to have 1st dimension bs=
{
batch_size
}
, got
{
encoder_outputs
[
0
].
shape
[
0
]
}
"
expanded_idx
=
(
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs
=
(
torch
.
arange
(
batch_size
)
torch
.
arange
(
batch_size
)
.
view
(
-
1
,
1
)
.
view
(
-
1
,
1
)
.
repeat
(
1
,
num_beams
*
effective_batch_mult
)
.
repeat
(
1
,
num_beams
*
effective_batch_mult
)
.
view
(
-
1
)
.
view
(
-
1
)
.
to
(
input_ids
.
device
)
.
to
(
input_ids
.
device
)
)
)
encoder_outputs
=
(
encoder_outputs
[
0
].
index_select
(
batch_idx
,
expanded_idx
),
*
encoder_outputs
[
1
:])
# expand encoder_outputs
encoder_outputs
=
(
encoder_outputs
[
0
].
index_select
(
0
,
expanded_batch_idxs
),
*
encoder_outputs
[
1
:])
else
:
else
:
encoder_outputs
=
None
encoder_outputs
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment