Unverified Commit d3cb2888 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Not use -1e4 as attn mask (#17306)



* Use torch.finfo(self.dtype).min

* for GPTNeoX

* for Albert

* For Splinter

* Update src/transformers/models/data2vec/modeling_data2vec_audio.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix -inf used in Bart-like models

* Fix a few remaining -inf

* more fix

* clean up

* For CLIP

* For FSMT

* clean up

* fix test

* Add dtype argument and use it for LayoutLMv3

* update FlaxLongT5Attention
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent fdb12080
......@@ -751,15 +751,7 @@ class ModuleUtilsMixin:
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
elif self.dtype in [torch.bfloat16, torch.float32]:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
raise ValueError(
f"{self.dtype} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`"
)
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
return encoder_extended_attention_mask
......@@ -792,7 +784,7 @@ class ModuleUtilsMixin:
return extended_attention_mask
def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
......@@ -806,6 +798,9 @@ class ModuleUtilsMixin:
Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
if dtype is None:
dtype = self.dtype
if not (attention_mask.dim() == 2 and self.config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
if device is not None:
......@@ -836,8 +831,8 @@ class ModuleUtilsMixin:
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
return extended_attention_mask
def get_head_mask(
......
......@@ -728,7 +728,7 @@ class AlbertModel(AlbertPreTrainedModel):
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
......
......@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -96,7 +96,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......
......@@ -467,7 +467,7 @@ class CanineSelfAttention(nn.Module):
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
attention_mask = (1.0 - attention_mask.float()) * -10000.0
attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
# Apply the attention mask (precomputed for all layers in CanineModel forward() function)
attention_scores = attention_scores + attention_mask
......
......@@ -638,7 +638,9 @@ class CLIPTextTransformer(nn.Module):
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......@@ -670,11 +672,11 @@ class CLIPTextTransformer(nn.Module):
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len):
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(torch.tensor(float("-inf")))
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
......
......@@ -435,7 +435,7 @@ class CTRLModel(CTRLPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
......
......@@ -578,7 +578,8 @@ class Data2VecAudioEncoder(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
......@@ -188,7 +188,11 @@ class DecisionTransformerGPT2Attention(nn.Module):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -239,7 +243,11 @@ class DecisionTransformerGPT2Attention(nn.Module):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -578,7 +586,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
......@@ -1809,7 +1809,7 @@ class DetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
......
......@@ -211,7 +211,9 @@ class MultiHeadSelfAttention(nn.Module):
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
) # (bs, n_heads, q_length, k_length)
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
......
......@@ -323,8 +323,8 @@ def _prepare_fsmt_decoder_inputs(
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask
......@@ -908,7 +908,7 @@ class Attention(nn.Module):
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
......@@ -975,7 +975,7 @@ class Attention(nn.Module):
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a input_ids with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)
# Public API
......
......@@ -199,7 +199,11 @@ class GPT2Attention(nn.Module):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -250,7 +254,11 @@ class GPT2Attention(nn.Module):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -811,7 +819,7 @@ class GPT2Model(GPT2PreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
......@@ -189,7 +189,11 @@ class GPTNeoSelfAttention(nn.Module):
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -566,7 +570,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......
......@@ -196,7 +196,11 @@ class GPTNeoXAttention(nn.Module):
attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -214,7 +218,7 @@ class GPTNeoXAttention(nn.Module):
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, -10000.0)
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores
......@@ -460,7 +464,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......
......@@ -170,7 +170,12 @@ class GPTJAttention(nn.Module):
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
......@@ -605,7 +610,7 @@ class GPTJModel(GPTJPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......
......@@ -664,7 +664,8 @@ class HubertEncoder(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......@@ -753,7 +754,8 @@ class HubertEncoderStableLayerNorm(nn.Module):
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
......
......@@ -253,7 +253,11 @@ class ImageGPTAttention(nn.Module):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -304,7 +308,11 @@ class ImageGPTAttention(nn.Module):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
......@@ -765,7 +773,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment