Unverified commit 3ee431dd, authored by Sam Shleifer, committed by GitHub

[Bart/Memory] Two separate, smaller decoder attention masks (#3371)

parent 53fe7338
@@ -74,39 +74,37 @@ BART_INPUTS_DOCSTRING = r"""
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
         decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
             Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
-        decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, 1, tgt_seq_len, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
-            Default behavior: generate a tensor that ignores pad tokens and future tokens, as in the paper.
+        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
+            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default.
             If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_bart_decoder_inputs` and modify it.
             See diagram 1 in the paper for more info on the default strategy.
 """
-LARGE_NEGATIVE = -1e8
+def invert_mask(attention_mask):
+    assert attention_mask.dim() == 2
+    return attention_mask.eq(0)
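invert_mask flips Hugging Face's 1-means-keep convention into the True-means-ignore convention that fairseq-style key_padding_mask arguments expect. A quick illustration in plain PyTorch (values made up):

    import torch

    attention_mask = torch.tensor([[1, 1, 0]])  # 1 = attend, 0 = pad
    padding_mask = attention_mask.eq(0)         # what invert_mask returns
    assert padding_mask.tolist() == [[False, False, True]]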
 def _prepare_bart_decoder_inputs(
-    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
+    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
 ):
-    """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
+    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
     Note: this is not called during generation
     """
     pad_token_id = config.pad_token_id
-    need_causal_mask = not config.output_past
     if decoder_input_ids is None:
         decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
-    bsz, tgt_len = decoder_input_ids.size()[:2]
+    bsz, tgt_len = decoder_input_ids.size()
-    if decoder_attn_mask is None:
+    if decoder_padding_mask is None:
         decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
-    if need_causal_mask:
-        causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
-    else:
-        causal_lm_mask = None
-    new_shape = (bsz, tgt_len, tgt_len)
-    # make it broadcastable so can just be added to the attention coefficients
-    decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
-    if mask_dtype is not None:
-        decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
-    assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
-    return decoder_input_ids, decoder_attn_mask
+    else:
+        decoder_padding_mask = invert_mask(decoder_padding_mask)
+    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
+        dtype=causal_mask_dtype, device=decoder_input_ids.device
+    )
+    return decoder_input_ids, decoder_padding_mask, causal_mask
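The function now returns two small masks instead of one combined 4D tensor: a per-batch boolean padding mask and a single float causal mask shared across the batch. A self-contained sketch of the shapes involved (toy values; the library's make_padding_mask/fill_with_neg_inf helpers are inlined here):

    import torch

    pad_token_id = 1
    decoder_input_ids = torch.tensor([[5, 7, pad_token_id]])
    bsz, tgt_len = decoder_input_ids.size()

    # (bsz, tgt_len) bool mask, True at pad positions (mirrors make_padding_mask).
    decoder_padding_mask = decoder_input_ids.eq(pad_token_id)

    # (tgt_len, tgt_len) float mask, -inf strictly above the diagonal, shared by the batch.
    causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), 1)

    print(decoder_padding_mask.shape, causal_mask.shape)  # torch.Size([1, 3]) torch.Size([3, 3])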
 class PretrainedBartModel(PreTrainedModel):
@@ -130,12 +128,9 @@ class PretrainedBartModel(PreTrainedModel):
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
         input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids)
         dummy_inputs = {
-            "decoder_input_ids": decoder_input_ids,
             "attention_mask": input_ids.ne(pad_token),
             "input_ids": input_ids,
-            "decoder_attention_mask": decoder_attn_mask,
         }
         return dummy_inputs
@@ -153,21 +148,6 @@ def _check_shapes(shape_1, shape2):
         raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
-def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
-    """Make one mask of shape (bsz, 1, tgt_len, src_len)"""
-    a = torch.zeros(targ_size)  # targ_size is (bsz, tgt_len, src_len)
-    b = torch.zeros(targ_size)
-    if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
-        _check_shapes(key_padding_mask.shape, targ_size[:2])
-        reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
-        a[reshaped] = LARGE_NEGATIVE
-    if causal_lm_mask is not None:  # (tgt_len, src_len) -> targ_size
-        _check_shapes(causal_lm_mask.shape, targ_size[-2:])
-        b = causal_lm_mask.unsqueeze(0).expand(*targ_size)
-    return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
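This deleted helper is where the memory cost lived: it materialized an additive float mask of shape (bsz, 1, tgt_len, tgt_len) on every forward pass. The replacement keeps only a (bsz, tgt_len) bool padding mask plus one (tgt_len, tgt_len) causal mask shared across the batch. A rough back-of-the-envelope sketch with assumed sizes (not taken from the commit):

    import torch

    bsz, tgt_len = 8, 1024
    combined = torch.zeros(bsz, 1, tgt_len, tgt_len)       # old: per-example float mask
    padding = torch.zeros(bsz, tgt_len, dtype=torch.bool)  # new: 1 byte per position
    causal = torch.zeros(tgt_len, tgt_len)                 # new: shared across the batch

    old_bytes = combined.numel() * combined.element_size()
    new_bytes = padding.numel() * padding.element_size() + causal.numel() * causal.element_size()
    print(old_bytes // 2**20, "MiB vs", new_bytes // 2**20, "MiB")  # 32 MiB vs 4 MiB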
 def shift_tokens_right(input_ids, pad_token_id):
     """Shift input ids one token to the right, and wrap the last non-pad token (usually <eos>)."""
     prev_output_tokens = input_ids.clone()
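The docstring's wrap-around behavior is easiest to see on a toy batch. A sketch of the effect described above (illustrative token ids, with 2 playing the role of <eos> and 1 the role of pad; this reimplements the behavior, not necessarily the commit's exact code):

    import torch

    pad_token_id = 1
    input_ids = torch.tensor([[0, 5, 8, 2, pad_token_id]])
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens = input_ids.clone()
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze(1)  # last non-pad token wraps to front
    prev_output_tokens[:, 1:] = input_ids[:, :-1]                            # everything else shifts right
    # prev_output_tokens -> tensor([[2, 0, 5, 8, 2]])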
@@ -281,8 +261,7 @@ class BartEncoder(nn.Module):
         """
         # check attention mask and invert
         if attention_mask is not None:
-            assert attention_mask.dim() == 2
-            attention_mask = attention_mask.eq(0)
+            attention_mask = invert_mask(attention_mask)
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
@@ -339,7 +318,13 @@ class DecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)
     def forward(
-        self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
+        self,
+        x,
+        encoder_hidden_states,
+        encoder_attn_mask=None,
+        layer_state=None,
+        causal_mask=None,
+        decoder_padding_mask=None,
     ):
         residual = x
@@ -347,7 +332,12 @@ class DecoderLayer(nn.Module):
             layer_state = {}
         # next line mutates layer state
         x, self_attn_weights = self.self_attn(
-            query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
+            query=x,
+            key=x,
+            layer_state=layer_state,
+            key_padding_mask=decoder_padding_mask,
+            attn_mask=causal_mask,
+            need_weights=self.output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
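Passing key_padding_mask and attn_mask separately is the same convention as fairseq and torch.nn.MultiheadAttention: the padding mask is a (batch, src_len) bool tensor applied per example, while attn_mask is a single (tgt_len, tgt_len) tensor reused across the batch. A minimal sketch with the stock PyTorch module (Bart's SelfAttention is its own implementation, but the interface idea is the same):

    import torch
    import torch.nn as nn

    attn = nn.MultiheadAttention(embed_dim=16, num_heads=4)
    tgt_len, bsz = 5, 2
    x = torch.randn(tgt_len, bsz, 16)  # (seq, batch, dim) layout

    key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    key_padding_mask[:, -1] = True  # True = this key is padding, ignore it

    causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), 1)

    out, _ = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=causal_mask)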
@@ -412,7 +402,8 @@ class BartDecoder(nn.Module):
         input_ids,
         encoder_hidden_states,
         encoder_padding_mask,
-        combined_mask,
+        decoder_padding_mask,
+        decoder_causal_mask,
         decoder_cached_states=None,
         generation_mode=False,
         **unused
@@ -437,8 +428,7 @@ class BartDecoder(nn.Module):
         """
         # check attention mask and invert
         if encoder_padding_mask is not None:
-            assert encoder_padding_mask.dim() == 2
-            encoder_padding_mask = encoder_padding_mask.eq(0)
+            encoder_padding_mask = invert_mask(encoder_padding_mask)
         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=generation_mode)
@@ -458,7 +448,6 @@ class BartDecoder(nn.Module):
         all_hidden_states = ()
         all_self_attns = ()
         next_decoder_cache = []
         for i, decoder_layer in enumerate(self.layers):
             decoder_layer  # type: DecoderLayer
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -468,7 +457,12 @@ class BartDecoder(nn.Module):
             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
             x, layer_self_attn, layer_past = decoder_layer(
-                x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
+                x,
+                encoder_hidden_states,
+                encoder_attn_mask=encoder_padding_mask,
+                decoder_padding_mask=decoder_padding_mask,
+                layer_state=layer_state,
+                causal_mask=decoder_causal_mask,
             )
             if self.output_past:
@@ -736,6 +730,8 @@ def _filter_out_falsey_values(tup) -> Tuple:
 # Public API
+def _get_shape(t):
+    return getattr(t, "shape", None)
 @add_start_docstrings(
@@ -769,13 +765,16 @@ class BartModel(PretrainedBartModel):
         # make masks if user doesn't supply
         if not generation_mode:
-            decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
+            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                 self.config,
                 input_ids,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attn_mask=decoder_attention_mask,
-                mask_dtype=self.shared.weight.dtype,
+                decoder_padding_mask=decoder_attention_mask,
+                causal_mask_dtype=self.shared.weight.dtype,
             )
+        else:
+            decoder_padding_mask, causal_mask = None, None
         assert decoder_input_ids is not None
         if encoder_outputs is None:
             encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -785,7 +784,8 @@ class BartModel(PretrainedBartModel):
             decoder_input_ids,
             encoder_outputs[0],
             attention_mask,
-            decoder_attention_mask,
+            decoder_padding_mask,
+            decoder_causal_mask=causal_mask,
             decoder_cached_states=decoder_cached_states,
             generation_mode=generation_mode,
         )
@@ -36,8 +36,8 @@ if is_torch_available():
     from transformers.modeling_bart import (
         BART_PRETRAINED_MODEL_ARCHIVE_MAP,
         shift_tokens_right,
+        invert_mask,
         _prepare_bart_decoder_inputs,
-        LARGE_NEGATIVE,
     )
     from transformers.tokenization_bart import BartTokenizer
@@ -123,10 +123,9 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()
-    def test_advanced_inputs(self):
+    def test_initialization_more(self):
         # (config, input_ids, token_type_ids, input_mask, *unused) = \
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"])
         model = BartModel(config)
         model.to(torch_device)
         model.eval()
@@ -142,9 +141,17 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         _check_var(model.encoder.layers[0].fc1)
         _check_var(model.encoder.embed_positions)
+    def test_advanced_inputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        inputs_dict["input_ids"][:, -2:] = config.pad_token_id
+        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
+            config, inputs_dict["input_ids"]
+        )
+        model = BartModel(config).to(torch_device).eval()
         decoder_features_with_created_mask = model(**inputs_dict)[0]
         decoder_features_with_passed_mask = model(
-            decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
+            decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
         )[0]
         _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
         useless_mask = torch.zeros_like(decoder_attn_mask)
@@ -238,7 +245,7 @@ class BartHeadTests(unittest.TestCase):
         lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
         lm_model = BartForConditionalGeneration(config)
         lm_model.to(torch_device)
-        loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
+        loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels)
         expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)
         self.assertIsInstance(loss.item(), float)
@@ -336,41 +343,23 @@ class BartHeadTests(unittest.TestCase):
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
     def test_dummy_inputs(self):
-        config, *_ = self._get_config_and_data(output_past=True)
+        config, *_ = self._get_config_and_data()
         model = BartForConditionalGeneration(config).eval().to(torch_device)
         model(**model.dummy_inputs)
     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
-        input_ids = _long_tensor(([4, 4, 2]))  # only used for .device if decoder_input_ids is passed
+        input_ids = _long_tensor(([4, 4, 2]))
         decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
-        ignore = LARGE_NEGATIVE
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
-        expected_mask = torch.tensor(
-            [
-                [0, ignore, ignore],
-                [0, 0, ignore],
-                [ignore, ignore, ignore],  # never attend to the final token, because it's pad
-            ]
-        ).to(input_ids.device)
-        self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
-        self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
-        # Test no causal mask
-        config, *_ = self._get_config_and_data(output_past=True)
-        expected_just_padding_mask = torch.tensor(
-            [[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]]  # never attend to the final token, because it's pad
-        ).to(input_ids.device)
-        _, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
-        self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
-        self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
-        decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
-        # Attend to everything if no pad tokens and no causal mask
-        _, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
+        ignore = float("-inf")
+        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
             config, input_ids, decoder_input_ids
         )
-        self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
+        expected_causal_mask = torch.tensor(
+            [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]]  # never attend to future tokens
+        ).to(input_ids.device)
+        self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
+        self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
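The expected_causal_mask in the new test is just the strict upper triangle filled with -inf, so token i can only attend to positions 0..i. A quick check reproducing the 3x3 pattern (plain PyTorch, independent of the library helpers):

    import torch

    ignore = float("-inf")
    causal_mask = torch.triu(torch.full((3, 3), ignore), 1)
    # [[0., -inf, -inf],
    #  [0.,   0., -inf],
    #  [0.,   0.,   0.]]
    expected = torch.tensor([[0.0, ignore, ignore], [0.0, 0.0, ignore], [0.0, 0.0, 0.0]])
    assert torch.equal(causal_mask, expected)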
     def test_resize_tokens_embeddings_more(self):
         config, input_ids, _ = self._get_config_and_data()