Unverified commit d3cb2888, authored by Yih-Dar, committed by GitHub

Not use -1e4 as attn mask (#17306)



* Use torch.finfo(self.dtype).min

* for GPTNeoX

* for Albert

* For Splinter

* Update src/transformers/models/data2vec/modeling_data2vec_audio.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix -inf used in Bart-like models

* Fix a few remaining -inf

* more fix

* clean up

* For CLIP

* For FSMT

* clean up

* fix test

* Add dtype argument and use it for LayoutLMv3

* update FlaxLongT5Attention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fdb12080
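
The thread running through every hunk below is the same: replace hard-coded mask fill values (`-1e4`, `-1e9`, `float("-inf")`) with the most negative finite value of the working dtype, `torch.finfo(dtype).min`. A minimal, self-contained sketch of what that value is per dtype and of the additive-mask pattern the diff standardizes on (the mask values here are illustrative):

```python
import torch

# Most negative finite value per dtype - this is what replaces -1e4 / -1e9 / -inf.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, torch.finfo(dtype).min)
# torch.float16  -> -65504.0
# torch.bfloat16 -> ~ -3.39e38
# torch.float32  -> ~ -3.40e38

# Additive mask pattern: 1 = attend (adds 0.0), 0 = masked (adds the dtype minimum).
attention_mask = torch.tensor([[1, 1, 0]])
dtype = torch.float16
extended = (1.0 - attention_mask[:, None, None, :].to(dtype)) * torch.finfo(dtype).min
print(extended)  # tensor([[[[0., 0., -65504.]]]], dtype=torch.float16)
```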
@@ -751,15 +751,7 @@ class ModuleUtilsMixin:
         # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
         # encoder_extended_attention_mask.transpose(-1, -2))
         encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        if self.dtype == torch.float16:
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
-        elif self.dtype in [torch.bfloat16, torch.float32]:
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
-        else:
-            raise ValueError(
-                f"{self.dtype} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`"
-            )
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min

         return encoder_extended_attention_mask
@@ -792,7 +784,7 @@ class ModuleUtilsMixin:
         return extended_attention_mask

     def get_extended_attention_mask(
-        self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
+        self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
     ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -806,6 +798,9 @@
         Returns:
             `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
         """
+        if dtype is None:
+            dtype = self.dtype
+
         if not (attention_mask.dim() == 2 and self.config.is_decoder):
             # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
             if device is not None:
@@ -836,8 +831,8 @@
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
         return extended_attention_mask

     def get_head_mask(
......
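
The new optional `dtype` argument above lets a caller build the additive mask in a dtype other than the module's parameter dtype (added for LayoutLMv3, per the commit message); it defaults to `self.dtype` when omitted. A hedged usage sketch with a randomly initialised BERT and an illustrative mask:

```python
import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig())                # random weights, only used for the mask helper
attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = attend, 0 = padding

extended = model.get_extended_attention_mask(
    attention_mask, input_shape=attention_mask.shape, dtype=torch.float16
)
print(extended)
# shape (1, 1, 1, 4): 0.0 at kept positions, torch.finfo(torch.float16).min at the padded one
```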
@@ -728,7 +728,7 @@ class AlbertModel(AlbertPreTrainedModel):
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         embedding_output = self.embeddings(
......
@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
......
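
The same one-line change is applied to each of the Bart-style `_make_causal_mask` helpers in this and the following hunks. Restated standalone and simplified (the `past_key_values_length` handling is omitted, and the fill value is passed as a plain float), the helper behaves as follows after this diff:

```python
import torch

def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype):
    """Simplified sketch of the Bart-style causal mask after this change."""
    bsz, tgt_len = input_ids_shape
    # Fill with the dtype's minimum instead of -inf.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask_cond = torch.arange(mask.size(-1))
    # Zero out the current and past positions (lower triangle incl. diagonal).
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)

print(_make_causal_mask(torch.Size([1, 3]), torch.float16)[0, 0])
# row i attends to positions <= i (0.0); future positions get -65504.0 instead of -inf
```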
@@ -96,7 +96,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), float("-inf"))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
......
@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
......
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
......
@@ -467,7 +467,7 @@ class CanineSelfAttention(nn.Module):
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
             # positions we want to attend and -10000.0 for masked positions.
-            attention_mask = (1.0 - attention_mask.float()) * -10000.0
+            attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
             # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
             attention_scores = attention_scores + attention_mask
......
@@ -638,7 +638,9 @@ class CLIPTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
+        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
+            hidden_states.device
+        )
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -670,11 +672,11 @@ class CLIPTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )

-    def _build_causal_attention_mask(self, bsz, seq_len):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len)
-        mask.fill_(torch.tensor(float("-inf")))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+        mask.fill_(torch.tensor(torch.finfo(dtype).min))
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
......
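
CLIP's helper now receives the compute dtype, so the causal mask is created in that dtype and filled with its minimum rather than a literal `-inf`. A standalone, module-free restatement of the resulting mask (a sketch, not the method itself):

```python
import torch

def build_causal_attention_mask(bsz: int, seq_len: int, dtype: torch.dtype):
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.finfo(dtype).min)  # dtype minimum instead of -inf
    mask.triu_(1)                       # zero out the lower diagonal
    return mask.unsqueeze(1)            # (bsz, 1, seq_len, seq_len)

print(build_causal_attention_mask(1, 3, torch.float16)[0, 0])
# tensor([[     0., -65504., -65504.],
#         [     0.,      0., -65504.],
#         [     0.,      0.,      0.]], dtype=torch.float16)
```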
@@ -435,7 +435,7 @@ class CTRLModel(CTRLPreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
......
@@ -578,7 +578,8 @@ class Data2VecAudioEncoder(nn.Module):
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
......
@@ -188,7 +188,11 @@ class DecisionTransformerGPT2Attention(nn.Module):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -239,7 +243,11 @@ class DecisionTransformerGPT2Attention(nn.Module):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -578,7 +586,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
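
The GPT-2-family attention blocks (above, and again in the GPT-2, GPT-Neo, GPT-NeoX, GPT-J and ImageGPT hunks below) all switch from the buffered `masked_bias` to a `mask_value` built from `torch.finfo`. Per the in-diff comments, the value is materialised as a tensor with the scores' dtype and device so `torch.where` sees matching operands. A small sketch of the resulting pattern with illustrative shapes:

```python
import torch

attn_weights = torch.randn(1, 1, 4, 4, dtype=torch.float16)   # illustrative attention scores
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # True = may attend

mask_value = torch.finfo(attn_weights.dtype).min
# Materialise as a tensor on the scores' dtype/device, as the diff's comments require.
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
```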
@@ -1809,7 +1809,7 @@ class DetrMHAttentionMap(nn.Module):
         weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

         if mask is not None:
-            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
+            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
         weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
         weights = self.dropout(weights)
         return weights
......
@@ -211,7 +211,9 @@ class MultiHeadSelfAttention(nn.Module):
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
-        scores = scores.masked_fill(mask, torch.tensor(-float("inf")))  # (bs, n_heads, q_length, k_length)
+        scores = scores.masked_fill(
+            mask, torch.tensor(torch.finfo(scores.dtype).min)
+        )  # (bs, n_heads, q_length, k_length)

         weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
         weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
......
@@ -323,8 +323,8 @@ def _prepare_fsmt_decoder_inputs(
         decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
     else:
         decoder_padding_mask = invert_mask(decoder_padding_mask)
-    causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
-        dtype=causal_mask_dtype, device=decoder_input_ids.device
+    causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
+        device=decoder_input_ids.device
     )
     return decoder_input_ids, decoder_padding_mask, causal_mask
@@ -908,7 +908,7 @@ class Attention(nn.Module):
         if key_padding_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
-            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
+            attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -975,7 +975,7 @@ class Attention(nn.Module):
 def fill_with_neg_inf(t):
     """FP16-compatible function that fills a input_ids with -inf."""
-    return t.float().fill_(float("-inf")).type_as(t)
+    return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)

 # Public API
......
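
With the FSMT change above, `fill_with_neg_inf` now fills with the minimum of the input's own dtype, so the value survives the round trip through `float()`/`type_as` for fp16 inputs. Restated standalone for illustration:

```python
import torch

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a input_ids with -inf."""
    return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)

print(fill_with_neg_inf(torch.zeros(2, 2, dtype=torch.float16)))
# tensor([[-65504., -65504.],
#         [-65504., -65504.]], dtype=torch.float16)
```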
@@ -199,7 +199,11 @@ class GPT2Attention(nn.Module):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -250,7 +254,11 @@ class GPT2Attention(nn.Module):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -811,7 +819,7 @@ class GPT2Model(GPT2PreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
@@ -189,7 +189,11 @@ class GPTNeoSelfAttention(nn.Module):
         query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -566,7 +570,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
......
@@ -196,7 +196,11 @@ class GPTNeoXAttention(nn.Module):
         attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
-        attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
+        mask_value = torch.finfo(attn_scores.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
+        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -214,7 +218,7 @@ class GPTNeoXAttention(nn.Module):
 def attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(~ltor_mask, -10000.0)
+    attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
     return attention_scores
@@ -460,7 +464,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
......
@@ -170,7 +170,12 @@ class GPTJAttention(nn.Module):
         key = key.to(torch.float32)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         attn_weights = attn_weights / self.scale_attn
@@ -605,7 +610,7 @@ class GPTJModel(GPTJPreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
......
@@ -664,7 +664,8 @@ class HubertEncoder(nn.Module):
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
@@ -753,7 +754,8 @@ class HubertEncoderStableLayerNorm(nn.Module):
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
......
@@ -253,7 +253,11 @@ class ImageGPTAttention(nn.Module):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -304,7 +308,11 @@ class ImageGPTAttention(nn.Module):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -765,7 +773,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......