Unverified Commit d3cb2888 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Not use -1e4 as attn mask (#17306)



* Use torch.finfo(self.dtype).min

* for GPTNeoX

* for Albert

* For Splinter

* Update src/transformers/models/data2vec/modeling_data2vec_audio.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix -inf used in Bart-like models

* Fix a few remaining -inf

* more fix

* clean up

* For CLIP

* For FSMT

* clean up

* fix test

* Add dtype argument and use it for LayoutLMv3

* update FlaxLongT5Attention
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent fdb12080
......@@ -910,8 +910,8 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)
if attention_mask is not None:
start_logits = start_logits + (1 - attention_mask) * -10000.0
end_logits = end_logits + (1 - attention_mask) * -10000.0
start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min
end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min
total_loss = None
if start_positions is not None and end_positions is not None:
......@@ -1060,8 +1060,8 @@ class SplinterForPreTraining(SplinterPreTrainedModel):
attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
batch_size, num_questions, sequence_length
)
start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0
end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0
start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min
end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min
total_loss = None
# [batch_size, num_questions, sequence_length]
......
......@@ -409,10 +409,11 @@ class FlaxT5Attention(nn.Module):
# replace masked positions with -10_000
if attention_mask is not None:
mask_value = jnp.finfo(self.dtype).min
attention_mask = jax.lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
)
if position_bias is None:
......
......@@ -333,7 +333,7 @@ class CausalSelfAttention(nn.Module):
# [ batch_size x n_heads x sequence_length x sequence_length ]
attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1)))
attn_weights = attn_weights.masked_fill(
self.mask[:, :, :sequence_length, :sequence_length] == 0, float("-inf")
self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min
)
attn_weights = F.softmax(attn_weights, dim=-1)
self._attn_map = attn_weights.clone()
......
......@@ -327,21 +327,17 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
attn_score = AC + BD
attn_score.mul_(self.scale)
mask_value = torch.finfo(attn_score.dtype).min
# compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = attn_mask == 1 # Switch to bool
if attn_mask.dim() == 2:
if next(self.parameters()).dtype == torch.float16:
attn_score = (
attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score)
attn_score.float().masked_fill(attn_mask[None, :, :, None], mask_value).type_as(attn_score)
)
else:
attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3:
if next(self.parameters()).dtype == torch.float16:
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], mask_value).type_as(attn_score)
# [qlen x klen x bsz x n_head]
attn_prob = nn.functional.softmax(attn_score, dim=1)
......
......@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -701,7 +701,8 @@ class UniSpeechEncoder(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......@@ -790,7 +791,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
......@@ -715,7 +715,8 @@ class UniSpeechSatEncoder(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......@@ -804,7 +805,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
......@@ -1433,7 +1433,7 @@ class VisualBertRegionToPhraseAttention(nn.Module):
def forward(self, query, key, attention_mask):
attention_mask = attention_mask.to(query.dtype)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min
mixed_query_layer = self.query(query)
mixed_key_layer = self.key(key)
......
......@@ -749,7 +749,8 @@ class Wav2Vec2Encoder(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......@@ -837,7 +838,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
......@@ -895,7 +895,8 @@ class Wav2Vec2ConformerEncoder(nn.Module):
hidden_states[~attention_mask] = 0.0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
......@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -181,7 +181,7 @@ class MultiHeadAttention(nn.Module):
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
......
......@@ -1632,7 +1632,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -351,9 +351,10 @@ class FSMTHeadTests(unittest.TestCase):
config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = float("-inf")
causal_mask_dtype = torch.float32
ignore = torch.finfo(causal_mask_dtype).min
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids
config, input_ids, decoder_input_ids, causal_mask_dtype=causal_mask_dtype
)
expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment