Unverified Commit 49204c1d authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Better SDPA unmasking implementation (#29318)

* better unmask imple

* comment

* typo

* bug report pytorch

* cleanup

* fix import

* add back example

* retrigger ci

* come on
parent f54d82ca
......@@ -187,7 +187,8 @@ class AttentionMaskConverter:
@staticmethod
def _unmask_unattended(
expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
expanded_mask: torch.FloatTensor,
min_dtype: float,
):
# fmt: off
"""
......@@ -200,13 +201,7 @@ class AttentionMaskConverter:
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
For example, if `attention_mask` is
```
[[0, 0, 1],
[1, 1, 1],
[0, 1, 1]]
```
and `expanded_mask` is (e.g. here left-padding case)
For example, if `expanded_mask` is (e.g. here left-padding case)
```
[[[[0, 0, 0],
[0, 0, 0],
......@@ -232,47 +227,12 @@ class AttentionMaskConverter:
```
"""
# fmt: on
# Get the index of the first non-zero value for every sample in the batch.
# In the above example, indices = [[2], [0], [1]]]
tmp = torch.arange(attention_mask.shape[1], 0, -1)
indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
# Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
# expanded mask will be completely unattended.
left_masked_rows = torch.where(indices > 0)[0]
if left_masked_rows.shape[0] == 0:
return expanded_mask
indices = indices[left_masked_rows]
max_len = torch.max(indices)
range_tensor = torch.arange(max_len).unsqueeze(0)
range_tensor = range_tensor.repeat(indices.size(0), 1)
# Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
range_tensor[range_tensor >= indices] = 0
# TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
if expanded_mask.dim() == 4:
num_masks = expanded_mask.shape[1]
if num_masks == 1:
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
mask_slice = (left_masked_rows[:, None], 0, range_tensor)
else:
# Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
mask_slice = (
left_masked_rows[:, None, None],
torch.arange(num_masks)[None, :, None],
range_tensor[:, None, :],
if expanded_mask.dtype == torch.bool:
raise ValueError(
"AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
)
else:
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
mask_slice = (left_masked_rows[:, None], range_tensor)
expanded_mask[mask_slice] = unmasked_value
return expanded_mask
return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
def _prepare_4d_causal_attention_mask(
......@@ -406,15 +366,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
key_value_length=key_value_length,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
#
# This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
# controlflow that can not be captured properly.
# TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
if query_length > 1 and not is_tracing:
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
if not is_tracing and expanded_4d_mask.device.type == "cuda":
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
expanded_4d_mask, attention_mask, unmasked_value=0.0
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
)
return expanded_4d_mask
......
......@@ -438,9 +438,9 @@ class FalconAttention(nn.Module):
else:
present = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
# For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
......@@ -456,6 +456,7 @@ class FalconAttention(nn.Module):
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attention_scores = None
else:
attention_scores = query_layer @ key_layer.transpose(-1, -2)
......@@ -1112,18 +1113,17 @@ class FalconModel(FalconPreTrainedModel):
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
min_dtype = torch.finfo(alibi.dtype).min
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
torch.finfo(alibi.dtype).min,
min_dtype,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1:
attention_mask = AttentionMaskConverter._unmask_unattended(
attention_mask, attention_mask_2d, unmasked_value=0.0
)
if seq_length > 1 and attention_mask.device.type == "cuda":
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
......
......@@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
......@@ -978,7 +979,11 @@ class GemmaModel(GemmaPreTrainedModel):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
......@@ -986,10 +991,10 @@ class GemmaModel(GemmaPreTrainedModel):
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
......
......@@ -30,6 +30,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
......@@ -534,21 +535,16 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
key = key.unsqueeze(1)
value = value.unsqueeze(1)
# Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to memory-efficient backend
# Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
#
# so we could do:
#
# key = key.expand(-1, self.num_heads, -1, -1)
# value = value.expand(-1, self.num_heads, -1, -1)
#
# However SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# so we always dispatch to the math path: https://github.com/pytorch/pytorch/issues/112577.
# Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch on memory-efficient
# backend, but it feels very hacky.
# torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
if is_torch_greater_or_equal_than_2_2:
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
else:
query_length = query_shape[-1]
......@@ -1020,6 +1016,15 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
if self._use_sdpa and head_mask is None and not output_attentions:
# SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
dtype = self.wte.weight.dtype
min_dtype = torch.finfo(dtype).min
self_attention_mask = torch.where(
self_attention_mask,
torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
)
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if self.multi_query:
......@@ -1027,21 +1032,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
# [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
self_attention_mask = self_attention_mask.transpose(1, 2)
if query_length > 1 and attention_mask is not None:
if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda":
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
self_attention_mask = AttentionMaskConverter._unmask_unattended(
self_attention_mask, attention_mask, unmasked_value=True
)
# SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
dtype = self.wte.weight.dtype
self_attention_mask = torch.where(
self_attention_mask,
torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
torch.full(
[], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
),
self_attention_mask, min_dtype=min_dtype
)
attention_mask = self_attention_mask
......
......@@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
......@@ -1090,7 +1091,11 @@ class LlamaModel(LlamaPreTrainedModel):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
......@@ -1098,10 +1103,10 @@ class LlamaModel(LlamaPreTrainedModel):
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment