chenpangpang / transformers · Commits · 11573231

Unverified commit 11573231, authored Mar 16, 2020 by Sam Shleifer, committed by GitHub on Mar 16, 2020

[BART] generation_mode as a kwarg not a class attribute (#3278)
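This change removes the stateful generation_mode attribute from BartDecoder and instead threads generation_mode through each forward call as a keyword argument. A minimal sketch of the before/after pattern (toy classes, not the transformers API; the only mode-dependent behavior modeled here is slicing to the newest token):

    import torch

    # Before: a flag mutated from outside, so forward() depends on hidden state.
    class StatefulDecoder:
        def __init__(self):
            self.generation_mode = False  # generate() must set this, then reset it

        def forward(self, input_ids):
            if self.generation_mode:
                input_ids = input_ids[:, -1:]  # keep only the newest token
            return input_ids

    # After: the caller states the mode explicitly, so each call is self-contained.
    class KwargDecoder:
        def forward(self, input_ids, generation_mode=False):
            if generation_mode:
                input_ids = input_ids[:, -1:]
            return input_ids

    ids = torch.tensor([[5, 6, 7]])
    print(KwargDecoder().forward(ids, generation_mode=True))  # tensor([[7]])

With the kwarg, a training forward pass can no longer run with a generation-time setting accidentally left over from a previous generate() call.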
Parent: de697935

Showing 2 changed files with 9 additions and 8 deletions:

    src/transformers/modeling_bart.py    +9 −4
    src/transformers/modeling_utils.py   +0 −4
src/transformers/modeling_bart.py
@@ -437,7 +437,6 @@ class BartDecoder(nn.Module):
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
         self.layernorm_embedding = LayerNorm(config.d_model)
-        self.generation_mode = False

     def forward(
         self,
@@ -446,6 +445,7 @@ class BartDecoder(nn.Module):
         encoder_padding_mask,
         combined_mask,
         decoder_cached_states=None,
+        generation_mode=False,
         **unused
     ):
         """
@@ -474,9 +474,9 @@ class BartDecoder(nn.Module):
             assert encoder_padding_mask.max() <= 0

         # embed positions
-        positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
+        positions = self.embed_positions(input_ids, generation_mode=generation_mode)

-        if self.generation_mode:
+        if generation_mode:
             input_ids = input_ids[:, -1:]
             positions = positions[:, -1:]  # happens after we embed them
         assert input_ids.ne(self.padding_idx).any()
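The slicing in this hunk is the incremental-decoding pattern: once decoder_cached_states covers the earlier positions, only the newest token needs to be embedded and decoded, and its position index continues from where the cache ends. A schematic sketch of that idea (hypothetical names and signature, not the transformers implementation):

    import torch
    import torch.nn as nn

    def incremental_embed(input_ids, past_len, tok_emb, pos_emb, generation_mode=False):
        # With a cache covering the first past_len positions, embed only the
        # newest token; its position index continues where the cache ends.
        if generation_mode:
            input_ids = input_ids[:, -1:]
            positions = torch.full((input_ids.shape[0], 1), past_len, dtype=torch.long)
        else:
            positions = torch.arange(input_ids.shape[1]).expand(input_ids.shape[0], -1)
        return tok_emb(input_ids) + pos_emb(positions)

Passing generation_mode as an argument makes the slice depend only on the current call, never on decoder state set elsewhere.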
@@ -820,10 +820,11 @@ class BartModel(PretrainedBartModel):
         encoder_outputs=None,  # type: Tuple
         decoder_attention_mask=None,
         decoder_cached_states=None,
+        generation_mode=False,
     ):
         # make masks if user doesn't supply
-        if not self.decoder.generation_mode:
+        if not generation_mode:
             decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
                 self.config,
                 input_ids,
@@ -842,6 +843,7 @@ class BartModel(PretrainedBartModel):
             attention_mask,
             decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
+            generation_mode=generation_mode,
         )
         # Attention and hidden_states will be [] or None if they aren't needed
         decoder_outputs = _filter_out_falsey_values(decoder_outputs)  # type: tuple
@@ -886,6 +888,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         lm_labels=None,
+        generation_mode=False,
         **unused
     ):
         r"""
@@ -936,6 +939,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
+            generation_mode=generation_mode,
         )
         lm_logits = self.lm_head(outputs[0])
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
@@ -963,6 +967,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             "decoder_cached_states": decoder_cached_states,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "generation_mode": True,
         }

     def prepare_scores_for_generation(self, scores, cur_len, max_length):
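The hard-coded "generation_mode": True above is what makes every generation-time forward call carry the flag, because generate() rebuilds its model inputs from prepare_inputs_for_generation at each step. A heavily simplified sketch of that loop shape (the real generate() also handles beam search, caching, and stopping criteria, and prepare_inputs_for_generation takes more arguments than shown here):

    import torch

    def greedy_loop(model, decoder_input_ids, max_new_tokens):
        for _ in range(max_new_tokens):
            # Returns the dict built above, including generation_mode=True.
            model_inputs = model.prepare_inputs_for_generation(decoder_input_ids)
            lm_logits = model(**model_inputs)[0]
            next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        return decoder_input_ids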
src/transformers/modeling_utils.py
@@ -846,7 +846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             attention_mask = attention_mask.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
-
         if self.config.is_encoder_decoder:
             assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
             # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
@@ -859,9 +858,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             )
             cur_len = 1

-            # put model in generation mode if it has one
-            if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
-                self.model.decoder.generation_mode = True
         else:
             encoder_inputs = None
             cur_len = input_ids.shape[-1]
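This removed block was the other half of the change: generate() no longer reaches into self.model.decoder to flip a flag, and, just as importantly, nothing has to remember to flip it back once generation finishes. A toy reproduction of the failure mode the kwarg avoids (not transformers code):

    class Decoder:
        def __init__(self):
            self.generation_mode = False

        def forward(self, input_ids):
            # Stand-in for the real slicing: report how many tokens get processed.
            return 1 if self.generation_mode else len(input_ids)

    decoder = Decoder()
    decoder.generation_mode = True     # what generate() used to do...
    print(decoder.forward([5, 6, 7]))  # 1, as intended during generation
    print(decoder.forward([5, 6, 7]))  # 1 again: a later training-style call is silently wrong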