Unverified Commit ed37f9fa authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Bart] _prepare_decoder_inputs should use large negative (#3158)

parent 0416d437
......@@ -65,7 +65,7 @@ BART_INPUTS_DOCSTRING = r"""
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
"""
LARGE_NEGATIVE = -1e4
LARGE_NEGATIVE = -1e8
def _prepare_bart_decoder_inputs(
......@@ -144,18 +144,18 @@ def _check_shapes(shape_1, shape2):
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
def _combine_masks(key_padding_mask, attn_mask, targ_size):
def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
# targ_size = (bsz, tgt_len, src_len)
a = torch.zeros(targ_size)
b = torch.zeros(targ_size)
if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size
_check_shapes(key_padding_mask.shape, targ_size[:2])
reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
a[reshaped] = 1e-8
a[reshaped] = LARGE_NEGATIVE
if attn_mask is not None: # (tgt_len, src_len) -> targ_size
_check_shapes(attn_mask.shape, targ_size[-2:])
b = attn_mask.unsqueeze(0).expand(*targ_size)
if causal_lm_mask is not None: # (tgt_len, src_len) -> targ_size
_check_shapes(causal_lm_mask.shape, targ_size[-2:])
b = causal_lm_mask.unsqueeze(0).expand(*targ_size)
return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
......
......@@ -37,6 +37,7 @@ if is_torch_available():
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
shift_tokens_right,
_prepare_bart_decoder_inputs,
LARGE_NEGATIVE,
)
from transformers.tokenization_bart import BartTokenizer
......@@ -303,6 +304,38 @@ class BartHeadTests(unittest.TestCase):
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
lm_model.generate(input_ids, attention_mask)
def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data(output_past=False)
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = LARGE_NEGATIVE
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
expected_mask = torch.tensor(
[
[0, ignore, ignore],
[0, 0, ignore],
[ignore, ignore, ignore], # never attend to the final token, because its pad
]
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
# Test no causal mask
config, *_ = self._get_config_and_data(output_past=True)
expected_just_padding_mask = torch.tensor(
[[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad
).to(input_ids.device)
_, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
# Attend to everything if no pad tokens and no causal mask
_, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids
)
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment