Unverified Commit d3cb2888 authored by Yih-Dar, committed by GitHub

Not use -1e4 as attn mask (#17306)



* Use torch.finfo(self.dtype).min

* for GPTNeoX

* for Albert

* For Splinter

* Update src/transformers/models/data2vec/modeling_data2vec_audio.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix -inf used in Bart-like models

* Fix a few remaining -inf

* more fix

* clean up

* For CLIP

* For FSMT

* clean up

* fix test

* Add dtype argument and use it for LayoutLMv3

* update FlaxLongT5Attention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fdb12080
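Why the change matters, as a minimal sketch (not part of the commit): in float16 a constant -10000.0 bias is representable but not necessarily large enough to dominate the raw attention scores, so a "masked" position can still win the softmax, whereas torch.finfo(dtype).min always suppresses it.

import torch

# Minimal sketch (not from this commit): a fixed -1e4 bias can fail to mask in fp16.
scores = torch.tensor([12000.0, 0.0], dtype=torch.float16)  # large logit at a masked position
weak = torch.softmax(scores + torch.tensor([-10000.0, 0.0], dtype=torch.float16), dim=-1)
safe = torch.softmax(scores + torch.tensor([torch.finfo(torch.float16).min, 0.0], dtype=torch.float16), dim=-1)
print(weak)  # ~[1.0, 0.0]: the "masked" position still dominates
print(safe)  # ~[0.0, 1.0]: the masked position is fully suppressed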
@@ -805,7 +805,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
if head_mask is not None:
if head_mask.dim() == 1:
......
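The LayoutLM hunk above uses the additive-mask idiom that recurs throughout this commit. A standalone sketch with illustrative names (not from the repo):

import torch

def extend_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch, seq_len] -> [batch, 1, 1, seq_len], broadcastable over heads and query positions
    extended = attention_mask[:, None, None, :].to(dtype)
    # 1 -> 0.0 (keep), 0 -> most negative representable value (mask out before softmax)
    return (1.0 - extended) * torch.finfo(dtype).min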
@@ -178,7 +178,9 @@ class LayoutLMv2SelfAttention(nn.Module):
attention_scores += rel_pos
if self.has_spatial_attention_bias:
attention_scores += rel_2d_pos
- attention_scores = attention_scores.float().masked_fill_(attention_mask.to(torch.bool), float("-inf"))
+ attention_scores = attention_scores.float().masked_fill_(
+     attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
+ )
attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -916,7 +918,7 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
if head_mask is not None:
if head_mask.dim() == 1:
......
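LayoutLMv2 instead fills masked slots in place after upcasting the scores; a toy sketch of that variant (shapes are illustrative):

import torch

scores = torch.randn(2, 12, 16, 16, dtype=torch.float16)  # (batch, heads, query, key)
pad = torch.zeros(2, 1, 1, 16)
pad[..., -4:] = 1.0  # here 1 marks padding positions
# upcast for a stable float32 softmax; fill masked slots with the fp16 minimum,
# mirroring the hunk above, which takes finfo of the pre-upcast dtype
scores = scores.float().masked_fill_(pad.to(torch.bool), torch.finfo(torch.float16).min)
probs = torch.softmax(scores, dim=-1)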
@@ -886,7 +886,9 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
position_ids = position_ids.expand_as(input_ids)
final_position_ids = position_ids
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+     attention_mask, None, device, dtype=embedding_output.dtype
+ )
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......
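The LayoutLMv3 hunk threads the embeddings' dtype into the mask helper, since the fill value is only meaningful relative to a dtype, as this quick check shows:

import torch

# The right fill value depends on the dtype the mask will be added in:
print(torch.finfo(torch.float16).min)  # -65504.0
print(torch.finfo(torch.float32).min)  # approx. -3.4028e+38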
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
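For reference, the surrounding Bart-style helper, reconstructed as a hedged standalone sketch (simplified: square mask only, no past_key_values handling):

import torch

def make_causal_mask(tgt_len: int, dtype: torch.dtype) -> torch.Tensor:
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask_cond = torch.arange(tgt_len)
    # zero out the positions a query may attend to: key index <= query index
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.to(dtype)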
@@ -207,7 +207,7 @@ class LEDEncoderSelfAttention(nn.Module):
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
- remove_from_windowed_attention_mask, -10000.0
+ remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
@@ -579,7 +579,7 @@ class LEDEncoderSelfAttention(nn.Module):
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
- ] = -10000.0
+ ] = torch.finfo(attn_probs_from_global_key.dtype).min
return attn_probs_from_global_key
@@ -675,11 +675,11 @@ class LEDEncoderSelfAttention(nn.Module):
global_attn_scores[
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
- ] = -10000.0
+ ] = torch.finfo(global_attn_scores.dtype).min
global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
- -10000.0,
+ torch.finfo(global_attn_scores.dtype).min,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
......
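The LED global-attention hunks write the minimum through advanced indexing rather than masked_fill; a toy sketch with made-up indices:

import torch

global_attn_scores = torch.randn(2, 4, 6, 6)  # (batch, heads, max_global, seq_len)
# illustrative (batch index, position index) pairs that have no global attention
no_global = (torch.tensor([0, 1]), torch.tensor([5, 3]))
global_attn_scores[no_global[0], :, no_global[1], :] = torch.finfo(global_attn_scores.dtype).min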
@@ -586,7 +586,7 @@ class LongformerSelfAttention(nn.Module):
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
- remove_from_windowed_attention_mask, -10000.0
+ remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
@@ -958,7 +958,7 @@ class LongformerSelfAttention(nn.Module):
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
- ] = -10000.0
+ ] = torch.finfo(attn_probs_from_global_key.dtype).min
return attn_probs_from_global_key
@@ -1054,11 +1054,11 @@ class LongformerSelfAttention(nn.Module):
global_attn_scores[
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
- ] = -10000.0
+ ] = torch.finfo(global_attn_scores.dtype).min
global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
- -10000.0,
+ torch.finfo(global_attn_scores.dtype).min,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
......
@@ -549,10 +549,11 @@ class FlaxLongT5Attention(nn.Module):
# replace masked positions with -10_000
if attention_mask is not None:
+ mask_value = jnp.finfo(self.dtype).min
attention_mask = jax.lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
- jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+ jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
)
if position_bias is None:
......
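On the Flax side the same fix goes through jnp.finfo; a minimal standalone sketch of the pattern (function name is illustrative):

import jax.numpy as jnp
from jax import lax

def mask_to_bias(attention_mask, dtype=jnp.float32):
    # 0.0 where attended, the dtype's minimum where masked
    mask_value = jnp.finfo(dtype).min
    return lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(dtype),
        jnp.full(attention_mask.shape, mask_value).astype(dtype),
    )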
@@ -1076,7 +1076,7 @@ class LukeModel(LukePreTrainedModel):
raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
return extended_attention_mask
......
@@ -959,13 +959,13 @@ class LxmertModel(LxmertPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
# Process the visual attention mask
if visual_attention_mask is not None:
extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
- extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
+ extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
else:
extended_visual_attention_mask = None
......
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -479,7 +479,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
......
@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -187,7 +187,9 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
"""
This function computes the bias for the predict stream
"""
- left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
+ left_block = (
+     torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
+ )
right_block = left_block.detach().clone()
# create bias
for stream_idx in range(ngram):
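The ngram bias previously used float("-inf"); finfo.min has the extra benefit that a row which ends up fully masked still softmaxes to a valid distribution instead of NaN:

import torch

# A fully masked row: NaN with -inf, a harmless uniform distribution with finfo.min.
print(torch.softmax(torch.full((4,), float("-inf")), dim=-1))                   # nan, nan, nan, nan
print(torch.softmax(torch.full((4,), torch.finfo(torch.float32).min), dim=-1))  # 0.2500 x 4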
@@ -1336,7 +1338,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
if attention_mask is not None:
extended_attention_mask = (
1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
- ) * -10000.0
+ ) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
else:
extended_attention_mask = None
@@ -1554,7 +1556,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
if encoder_attention_mask is not None:
extended_encoder_attention_mask = (
1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
- ) * -10000.0
+ ) * torch.finfo(self.dtype).min
extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
else:
extended_encoder_attention_mask = None
@@ -1713,7 +1715,10 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
# get causal mask
causal_mask = torch.full(
- (seq_length, seq_length), -float("inf"), dtype=hidden_states.dtype, device=hidden_states.device
+ (seq_length, seq_length),
+ torch.finfo(hidden_states.dtype).min,
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
)
causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
@@ -1722,7 +1727,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
# add usual attention mask
if attention_mask is not None:
- extended_attention_mask = (1.0 - attention_mask[:, None, :]) * -10000.0
+ extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_causal_mask + extended_attention_mask
else:
extended_attention_mask = extended_causal_mask
@@ -1750,7 +1755,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
# add usual attention mask
if attention_mask is not None:
- extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * -10000.0
+ extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
# predicted stream attention_mask should always be 0
extended_attention_mask = torch.cat(
......
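ProphetNet builds its causal mask with torch.full plus torch.triu and then adds the padding mask on top; a condensed sketch of that combination (values illustrative; note that two minimum-value terms can sum to -inf in float32, which the softmax tolerates for any row that keeps at least one unmasked entry):

import torch

dtype = torch.float32
seq_length = 5
causal_mask = torch.triu(
    torch.full((seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype), diagonal=1
)
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])  # illustrative padding mask
extended = causal_mask[None, :, :] + (1.0 - attention_mask[:, None, :]) * torch.finfo(dtype).min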
@@ -869,8 +869,8 @@ class RealmReaderProjection(nn.Module):
return starts, ends, span_masks
- def mask_to_score(mask):
-     return (1.0 - mask.type(torch.float32)) * -10000.0
+ def mask_to_score(mask, dtype=torch.float32):
+     return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min
# [reader_beam_size, max_sequence_len, span_hidden_size * 2]
hidden_states = self.dense_intermediate(hidden_states)
@@ -890,7 +890,7 @@ class RealmReaderProjection(nn.Module):
# [reader_beam_size, num_candidates]
reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
# [reader_beam_size, num_candidates]
- reader_logits += mask_to_score(candidate_mask)
+ reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)
return reader_logits, candidate_starts, candidate_ends
@@ -1634,11 +1634,11 @@ class RealmReader(RealmPreTrainedModel):
def marginal_log_loss(logits, is_correct):
"""Loss based on the negative marginal log-likelihood."""
- def mask_to_score(mask):
-     return (1.0 - mask.type(torch.float32)) * -10000.0
+ def mask_to_score(mask, dtype=torch.float32):
+     return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min
# []
- log_numerator = torch.logsumexp(logits + mask_to_score(is_correct), dim=-1)
+ log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)
log_denominator = torch.logsumexp(logits, dim=-1)
return log_denominator - log_numerator
......
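The Realm change threads the logits' dtype through the scoring helper; a self-contained sketch of the resulting pattern:

import torch

def mask_to_score(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

logits = torch.randn(3)
is_correct = torch.tensor([1.0, 0.0, 1.0])
# invalid entries get the dtype minimum, so logsumexp effectively skips them
log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)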
@@ -643,7 +643,8 @@ class SEWEncoder(nn.Module):
attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......