"...resnet50_tensorflow.git" did not exist on "2d1e1782ff07077b73b03c66d201770fa196809f"
Unverified Commit 49204c1d authored by fxmarty, committed by GitHub

Better SDPA unmasking implementation (#29318)

* better unmask imple

* comment

* typo

* bug report pytorch

* cleanup

* fix import

* add back example

* retrigger ci

* come on
parent f54d82ca
@@ -187,7 +187,8 @@ class AttentionMaskConverter:
     @staticmethod
     def _unmask_unattended(
-        expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
+        expanded_mask: torch.FloatTensor,
+        min_dtype: float,
     ):
         # fmt: off
         """
@@ -200,13 +201,7 @@ class AttentionMaskConverter:
         The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
-        For example, if `attention_mask` is
-        ```
-        [[0, 0, 1],
-         [1, 1, 1],
-         [0, 1, 1]]
-        ```
-        and `expanded_mask` is (e.g. here left-padding case)
+        For example, if `expanded_mask` is (e.g. here left-padding case)
         ```
         [[[[0, 0, 0],
            [0, 0, 0],
@@ -232,47 +227,12 @@ class AttentionMaskConverter:
         ```
         """
         # fmt: on
-        # Get the index of the first non-zero value for every sample in the batch.
-        # In the above example, indices = [[2], [0], [1]]]
-        tmp = torch.arange(attention_mask.shape[1], 0, -1)
-        indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
-
-        # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
-        # expanded mask will be completely unattended.
-        left_masked_rows = torch.where(indices > 0)[0]
-
-        if left_masked_rows.shape[0] == 0:
-            return expanded_mask
-        indices = indices[left_masked_rows]
-
-        max_len = torch.max(indices)
-        range_tensor = torch.arange(max_len).unsqueeze(0)
-        range_tensor = range_tensor.repeat(indices.size(0), 1)
-
-        # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
-        range_tensor[range_tensor >= indices] = 0
-
-        # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
-        if expanded_mask.dim() == 4:
-            num_masks = expanded_mask.shape[1]
-            if num_masks == 1:
-                # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
-                mask_slice = (left_masked_rows[:, None], 0, range_tensor)
-            else:
-                # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
-                mask_slice = (
-                    left_masked_rows[:, None, None],
-                    torch.arange(num_masks)[None, :, None],
-                    range_tensor[:, None, :],
-                )
-        else:
-            # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
-            mask_slice = (left_masked_rows[:, None], range_tensor)
-
-        expanded_mask[mask_slice] = unmasked_value
-
-        return expanded_mask
+        if expanded_mask.dtype == torch.bool:
+            raise ValueError(
+                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+            )
+
+        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

     def _prepare_4d_causal_attention_mask(
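For context on the hunk above: the new body replaces the index arithmetic with a single reduction that zeroes out any row of the additive mask that is entirely `min_dtype` (a row that attends to nothing, as left padding produces). A minimal standalone sketch of that behaviour with toy values (not part of the commit):

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Toy additive 4D mask [batch, num_masks, query_length, kv_length]: 0 means "attend",
# min_dtype means "masked". The first query row is fully masked (left-padding case).
expanded_mask = torch.tensor(
    [[[[min_dtype, min_dtype, min_dtype],
       [min_dtype, 0.0,       min_dtype],
       [min_dtype, 0.0,       0.0]]]],
    dtype=dtype,
)

# Rows that are entirely min_dtype are multiplied by False (= 0), so they now attend to
# every position; all other rows are multiplied by True (= 1) and stay unchanged.
fully_masked_rows = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
unmasked = expanded_mask.mul(~fully_masked_rows)

assert (unmasked[0, 0, 0] == 0).all()                            # fully masked row: now all zeros
assert torch.equal(unmasked[0, 0, 1:], expanded_mask[0, 0, 1:])  # other rows untouched
```

The whole operation is a broadcasted multiply, so it works for both 3D and 4D masks and avoids the in-place advanced indexing of the removed implementation.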
@@ -406,15 +366,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
             key_value_length=key_value_length,
         )

-        # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
-        # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-        #
-        # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
-        # controlflow that can not be captured properly.
-        # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
-        if query_length > 1 and not is_tracing:
+        # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        if not is_tracing and expanded_4d_mask.device.type == "cuda":
             expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
-                expanded_4d_mask, attention_mask, unmasked_value=0.0
+                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
             )

     return expanded_4d_mask
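The PyTorch issue referenced in the comments (pytorch/pytorch#110213) boils down to query rows that attend to nothing: the softmax over a fully masked row is 0/0. A small self-contained reproduction of the symptom and of why reopening such rows is harmless (a sketch, not part of the commit; on CPU this exercises the math backend rather than the memory-efficient one, but the failure mode is the same all-masked-row softmax):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 1, 2, 4)  # [batch, heads, query_length, head_dim]
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)

# Query row 0 attends to nothing (as happens for left-padded positions):
# softmax over an all-masked row is 0/0, so its output is NaN.
mask = torch.tensor([[[[False, False],
                       [True,  True]]]])
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # typically tensor([nan, nan, nan, nan])

# "Unmasking" that row (letting the padded position attend to everything) removes the NaN;
# its output is meaningless but belongs to a padding token and is ignored downstream.
mask[0, 0, 0, :] = True
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.isnan(out).any())  # tensor(False)
```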
@@ -438,9 +438,9 @@ class FalconAttention(nn.Module):
         else:
             present = None

-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_layer.device.type == "cuda" and attention_mask is not None:
+        if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
+            # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
+            # Reference: https://github.com/pytorch/pytorch/issues/112577.
             query_layer = query_layer.contiguous()
             key_layer = key_layer.contiguous()
             value_layer = value_layer.contiguous()
@@ -456,6 +456,7 @@ class FalconAttention(nn.Module):
                 # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
                 is_causal=self.is_causal and attention_mask is None and query_length > 1,
             )
             attention_scores = None
         else:
             attention_scores = query_layer @ key_layer.transpose(-1, -2)
@@ -1112,18 +1113,17 @@ class FalconModel(FalconPreTrainedModel):
             if attention_mask_2d is None:
                 attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
             else:
+                min_dtype = torch.finfo(alibi.dtype).min
                 attention_mask = torch.masked_fill(
                     alibi / math.sqrt(self.config.hidden_size // self.num_heads),
                     attention_mask < -1,
-                    torch.finfo(alibi.dtype).min,
+                    min_dtype,
                 )

                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                if seq_length > 1:
-                    attention_mask = AttentionMaskConverter._unmask_unattended(
-                        attention_mask, attention_mask_2d, unmasked_value=0.0
-                    )
+                if seq_length > 1 and attention_mask.device.type == "cuda":
+                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
         else:
             # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
             attention_mask = _prepare_4d_causal_attention_mask(
@@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
     _prepare_4d_causal_attention_mask,
 )
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
@@ -978,7 +979,11 @@ class GemmaModel(GemmaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+        ):
             # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
             is_tracing = (
                 torch.jit.is_tracing()
@@ -986,10 +991,10 @@ class GemmaModel(GemmaPreTrainedModel):
                 or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
             )
             if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
                 # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                 # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
+                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
@@ -30,6 +30,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -534,21 +535,16 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
             key = key.unsqueeze(1)
             value = value.unsqueeze(1)

-            # Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to memory-efficient backend
+            # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
             # and flash attention backend (No available kernel. Aborting execution.) from the shapes
             # query = [batch_size, num_heads, query_length, head_dim]
             # key = [batch_size, 1, past_length, head_dim]
             # value = [batch_size, 1, past_length, head_dim]
             #
-            # so we could do:
-            #
-            # key = key.expand(-1, self.num_heads, -1, -1)
-            # value = value.expand(-1, self.num_heads, -1, -1)
-            #
-            # However SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-            # so we always dispatch to the math path: https://github.com/pytorch/pytorch/issues/112577.
-            # Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch on memory-efficient
-            # backend, but it feels very hacky.
+            # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+            if is_torch_greater_or_equal_than_2_2:
+                key = key.expand(-1, self.num_heads, -1, -1)
+                value = value.expand(-1, self.num_heads, -1, -1)
         else:
             query_length = query_shape[-1]
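The `is_torch_greater_or_equal_than_2_2` flag imported above is transformers' own version check; the rough standalone sketch below (using `packaging` instead of that helper, so it is an approximation rather than the library's code) shows the shape situation the comment describes: with multi-query attention the key/value carry a single head, and expanding them lets torch >= 2.2 pick a fused SDPA kernel, while the broadcastable shapes still run on the math path either way.

```python
import torch
import torch.nn.functional as F
from packaging import version

batch_size, num_heads, query_length, past_length, head_dim = 2, 8, 1, 5, 64
query = torch.randn(batch_size, num_heads, query_length, head_dim)
key = torch.randn(batch_size, 1, past_length, head_dim)    # multi-query: one shared KV head
value = torch.randn(batch_size, 1, past_length, head_dim)

# torch==2.1.x memory-efficient SDPA is bugged with non-contiguous inputs plus a custom
# attn_mask (https://github.com/pytorch/pytorch/issues/112577), so only expand on >= 2.2,
# where the expanded KV heads let SDPA dispatch to the memory-efficient / flash kernels.
if version.parse(torch.__version__) >= version.parse("2.2"):
    key = key.expand(-1, num_heads, -1, -1)
    value = value.expand(-1, num_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 1, 64])
```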
@@ -1020,6 +1016,15 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)

         if self._use_sdpa and head_mask is None and not output_attentions:
+            # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
+            dtype = self.wte.weight.dtype
+            min_dtype = torch.finfo(dtype).min
+            self_attention_mask = torch.where(
+                self_attention_mask,
+                torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
+                torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
+            )
+
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             if self.multi_query:
@@ -1027,23 +1032,13 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                 # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
                 self_attention_mask = self_attention_mask.transpose(1, 2)

-            if query_length > 1 and attention_mask is not None:
+            if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda":
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 self_attention_mask = AttentionMaskConverter._unmask_unattended(
-                    self_attention_mask, attention_mask, unmasked_value=True
+                    self_attention_mask, min_dtype=min_dtype
                 )

-            # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
-            dtype = self.wte.weight.dtype
-            self_attention_mask = torch.where(
-                self_attention_mask,
-                torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
-                torch.full(
-                    [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
-                ),
-            )
-
             attention_mask = self_attention_mask

             # If a 2D or 3D attention mask is provided for the cross-attention
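Taken together, the two GPTBigCode hunks reorder the mask pipeline: cast the boolean mask to an additive float mask once (so `min_dtype` exists before it is needed), then reopen fully masked rows. A condensed sketch of that flow with toy tensors, calling the same private helper the diff uses (this assumes a transformers version that ships the new `min_dtype` signature, i.e. this commit or later):

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# Boolean padding mask, True = attend. The first sequence is fully padded here only to
# force a fully masked row in the toy example.
bool_mask = torch.tensor([[False, False, False],
                          [True,  True,  True]])[:, None, None, :]  # [batch, 1, 1, kv_length]

# 1. Cast once to an additive float mask (0 = attend, min_dtype = masked); SDPA is faster
#    with a fp16/fp32 mask than with a bool mask cast at every layer.
float_mask = torch.where(
    bool_mask,
    torch.full([], 0.0, dtype=dtype),
    torch.full([], min_dtype, dtype=dtype),
)

# 2. Reopen rows that are entirely masked so the memory-efficient SDPA path cannot produce NaNs.
float_mask = AttentionMaskConverter._unmask_unattended(float_mask, min_dtype=min_dtype)
print(float_mask[0, 0, 0])  # tensor([0., 0., 0.], dtype=torch.float16) for the fully padded row
```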
@@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -1090,7 +1091,11 @@ class LlamaModel(LlamaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+        ):
             # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
             is_tracing = (
                 torch.jit.is_tracing()
@@ -1098,10 +1103,10 @@ class LlamaModel(LlamaPreTrainedModel):
                 or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
             )
             if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
                 # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                 # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
+                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask