Unverified Commit d3cb2888 authored by Yih-Dar, committed by GitHub
Browse files

Not use -1e4 as attn mask (#17306)



* Use torch.finfo(self.dtype).min

* for GPTNeoX

* for Albert

* For Splinter

* Update src/transformers/models/data2vec/modeling_data2vec_audio.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix -inf used in Bart-like models

* Fix a few remaining -inf

* more fix

* clean up

* For CLIP

* For FSMT

* clean up

* fix test

* Add dtype argument and use it for LayoutLMv3

* update FlaxLongT5Attention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fdb12080
...@@ -910,8 +910,8 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel): ...@@ -910,8 +910,8 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1) start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)
if attention_mask is not None: if attention_mask is not None:
start_logits = start_logits + (1 - attention_mask) * -10000.0 start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min
end_logits = end_logits + (1 - attention_mask) * -10000.0 end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
...@@ -1060,8 +1060,8 @@ class SplinterForPreTraining(SplinterPreTrainedModel): ...@@ -1060,8 +1060,8 @@ class SplinterForPreTraining(SplinterPreTrainedModel):
attention_mask_for_each_question = attention_mask.unsqueeze(1).expand( attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
batch_size, num_questions, sequence_length batch_size, num_questions, sequence_length
) )
start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0 start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min
end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0 end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min
total_loss = None total_loss = None
# [batch_size, num_questions, sequence_length] # [batch_size, num_questions, sequence_length]
......
...@@ -409,10 +409,11 @@ class FlaxT5Attention(nn.Module): ...@@ -409,10 +409,11 @@ class FlaxT5Attention(nn.Module):
# replace masked positions with -10_000 # replace masked positions with -10_000
if attention_mask is not None: if attention_mask is not None:
mask_value = jnp.finfo(self.dtype).min
attention_mask = jax.lax.select( attention_mask = jax.lax.select(
attention_mask > 0, attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e4).astype(self.dtype), jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
) )
if position_bias is None: if position_bias is None:
......
...@@ -333,7 +333,7 @@ class CausalSelfAttention(nn.Module): ...@@ -333,7 +333,7 @@ class CausalSelfAttention(nn.Module):
# [ batch_size x n_heads x sequence_length x sequence_length ] # [ batch_size x n_heads x sequence_length x sequence_length ]
attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1))) attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1)))
attn_weights = attn_weights.masked_fill( attn_weights = attn_weights.masked_fill(
self.mask[:, :, :sequence_length, :sequence_length] == 0, float("-inf") self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min
) )
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = F.softmax(attn_weights, dim=-1)
self._attn_map = attn_weights.clone() self._attn_map = attn_weights.clone()
......
...@@ -327,21 +327,17 @@ class RelPartialLearnableMultiHeadAttn(nn.Module): ...@@ -327,21 +327,17 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
attn_score = AC + BD attn_score = AC + BD
attn_score.mul_(self.scale) attn_score.mul_(self.scale)
mask_value = torch.finfo(attn_score.dtype).min
# compute attention probability # compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item(): if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = attn_mask == 1 # Switch to bool attn_mask = attn_mask == 1 # Switch to bool
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
if next(self.parameters()).dtype == torch.float16: attn_score = (
attn_score = ( attn_score.float().masked_fill(attn_mask[None, :, :, None], mask_value).type_as(attn_score)
attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score) )
)
else:
attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if next(self.parameters()).dtype == torch.float16: attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], mask_value).type_as(attn_score)
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = nn.functional.softmax(attn_score, dim=1) attn_prob = nn.functional.softmax(attn_score, dim=1)
......
...@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ ...@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
......
...@@ -701,7 +701,8 @@ class UniSpeechEncoder(nn.Module): ...@@ -701,7 +701,8 @@ class UniSpeechEncoder(nn.Module):
hidden_states[~expand_attention_mask] = 0 hidden_states[~expand_attention_mask] = 0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
...@@ -790,7 +791,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): ...@@ -790,7 +791,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0 hidden_states[~expand_attention_mask] = 0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
......
...@@ -715,7 +715,8 @@ class UniSpeechSatEncoder(nn.Module): ...@@ -715,7 +715,8 @@ class UniSpeechSatEncoder(nn.Module):
hidden_states[~expand_attention_mask] = 0 hidden_states[~expand_attention_mask] = 0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
...@@ -804,7 +805,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): ...@@ -804,7 +805,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0 hidden_states[~expand_attention_mask] = 0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
......
...@@ -1433,7 +1433,7 @@ class VisualBertRegionToPhraseAttention(nn.Module): ...@@ -1433,7 +1433,7 @@ class VisualBertRegionToPhraseAttention(nn.Module):
def forward(self, query, key, attention_mask): def forward(self, query, key, attention_mask):
attention_mask = attention_mask.to(query.dtype) attention_mask = attention_mask.to(query.dtype)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0 attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min
mixed_query_layer = self.query(query) mixed_query_layer = self.query(query)
mixed_key_layer = self.key(key) mixed_key_layer = self.key(key)
......
...@@ -749,7 +749,8 @@ class Wav2Vec2Encoder(nn.Module): ...@@ -749,7 +749,8 @@ class Wav2Vec2Encoder(nn.Module):
hidden_states[~expand_attention_mask] = 0 hidden_states[~expand_attention_mask] = 0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
...@@ -837,7 +838,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): ...@@ -837,7 +838,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0 hidden_states[~expand_attention_mask] = 0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
......
...@@ -895,7 +895,8 @@ class Wav2Vec2ConformerEncoder(nn.Module): ...@@ -895,7 +895,8 @@ class Wav2Vec2ConformerEncoder(nn.Module):
hidden_states[~attention_mask] = 0.0 hidden_states[~attention_mask] = 0.0
# extend attention_mask # extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand( attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
) )
......
...@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ ...@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
......
...@@ -181,7 +181,7 @@ class MultiHeadAttention(nn.Module): ...@@ -181,7 +181,7 @@ class MultiHeadAttention(nn.Module):
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
......
...@@ -1632,7 +1632,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ ...@@ -1632,7 +1632,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
bsz, tgt_len = input_ids_shape bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf")) mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1)) mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype) mask = mask.to(dtype)
......
...@@ -351,9 +351,10 @@ class FSMTHeadTests(unittest.TestCase): ...@@ -351,9 +351,10 @@ class FSMTHeadTests(unittest.TestCase):
config, *_ = self._get_config_and_data() config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2])) input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]]) decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = float("-inf") causal_mask_dtype = torch.float32
ignore = torch.finfo(causal_mask_dtype).min
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs( decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids config, input_ids, decoder_input_ids, causal_mask_dtype=causal_mask_dtype
) )
expected_causal_mask = torch.tensor( expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment