Commit ed37f9fa (unverified) in chenpangpang/transformers
Authored Mar 06, 2020 by Sam Shleifer; committed via GitHub on Mar 06, 2020
[Bart] _prepare_decoder_inputs should use large negative (#3158)
Parent: 0416d437

Showing 2 changed files with 39 additions and 6 deletions:
  src/transformers/modeling_bart.py   +6  -6
  tests/test_modeling_bart.py        +33  -0
src/transformers/modeling_bart.py

@@ -65,7 +65,7 @@ BART_INPUTS_DOCSTRING = r"""
         If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
         See diagram 1 in the paper for more info on the default strategy
 """
-LARGE_NEGATIVE = -1e4
+LARGE_NEGATIVE = -1e8


 def _prepare_bart_decoder_inputs(
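For intuition (a sketch, not part of the commit): these masks are additive, i.e. they are added to the raw attention scores before the softmax, so the masking value must be a large negative number. With the new LARGE_NEGATIVE of -1e8, masked positions get essentially zero attention weight, while the old buggy padding value of 1e-8 (fixed in the next hunk) left them almost untouched. A minimal example with made-up scores:

    import torch
    import torch.nn.functional as F

    # One query attending over three keys; suppose the last key is padding.
    scores = torch.tensor([2.0, 1.0, 0.5])

    buggy = torch.tensor([0.0, 0.0, 1e-8])   # pre-fix padding value: effectively no masking
    fixed = torch.tensor([0.0, 0.0, -1e8])   # LARGE_NEGATIVE after this commit

    print(F.softmax(scores + buggy, dim=-1))  # pad key still gets ~14% of the attention
    print(F.softmax(scores + fixed, dim=-1))  # pad key gets 0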
@@ -144,18 +144,18 @@ def _check_shapes(shape_1, shape2):
         raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


-def _combine_masks(key_padding_mask, attn_mask, targ_size):
+def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
     # targ_size = (bsz, tgt_len, src_len)
     a = torch.zeros(targ_size)
     b = torch.zeros(targ_size)
     if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
         _check_shapes(key_padding_mask.shape, targ_size[:2])
         reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
-        a[reshaped] = 1e-8
-    if attn_mask is not None:  # (tgt_len, src_len) -> targ_size
-        _check_shapes(attn_mask.shape, targ_size[-2:])
-        b = attn_mask.unsqueeze(0).expand(*targ_size)
+        a[reshaped] = LARGE_NEGATIVE
+    if causal_lm_mask is not None:  # (tgt_len, src_len) -> targ_size
+        _check_shapes(causal_lm_mask.shape, targ_size[-2:])
+        b = causal_lm_mask.unsqueeze(0).expand(*targ_size)

     return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
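A side note on the unchanged return line (a toy illustration, not from the commit): positions masked by both the padding component a and the causal component b would sum to 2 * LARGE_NEGATIVE, and the trailing .clamp(LARGE_NEGATIVE,) floors the combined mask back at LARGE_NEGATIVE, since a single positional argument to Tensor.clamp is the minimum:

    import torch

    LARGE_NEGATIVE = -1e8
    a = torch.tensor([0.0, LARGE_NEGATIVE])             # padding component
    b = torch.tensor([LARGE_NEGATIVE, LARGE_NEGATIVE])  # causal component
    print((a + b).clamp(LARGE_NEGATIVE))  # tensor([-1.0000e+08, -1.0000e+08])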
tests/test_modeling_bart.py
@@ -37,6 +37,7 @@ if is_torch_available():
         BART_PRETRAINED_MODEL_ARCHIVE_MAP,
         shift_tokens_right,
         _prepare_bart_decoder_inputs,
+        LARGE_NEGATIVE,
     )
     from transformers.tokenization_bart import BartTokenizer
@@ -303,6 +304,38 @@ class BartHeadTests(unittest.TestCase):
         lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
         lm_model.generate(input_ids, attention_mask)

+    def test_prepare_bart_decoder_inputs(self):
+        config, *_ = self._get_config_and_data(output_past=False)
+        input_ids = _long_tensor(([4, 4, 2]))
+        # only used for .device if decoder_input_ids is passed
+        decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
+        ignore = LARGE_NEGATIVE
+        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
+        expected_mask = torch.tensor(
+            [
+                [0, ignore, ignore],
+                [0, 0, ignore],
+                [ignore, ignore, ignore],
+                # never attend to the final token, because its pad
+            ]
+        ).to(input_ids.device)
+        self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
+        self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
+
+        # Test no causal mask
+        config, *_ = self._get_config_and_data(output_past=True)
+        expected_just_padding_mask = torch.tensor(
+            [[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]]
+            # never attend to the final token, because its pad
+        ).to(input_ids.device)
+        _, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
+        self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
+        self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
+
+        decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
+        # Attend to everything if no pad tokens and no causal mask
+        _, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
+            config, input_ids, decoder_input_ids
+        )
+        self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())


 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
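To exercise the new test locally, a pytest filter along these lines should pick it up from a checkout of the repo at this commit (with torch installed):

    pytest tests/test_modeling_bart.py -k test_prepare_bart_decoder_inputs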