Unverified Commit 37a12c4e authored by Kirthi Shankar Sivamani, committed by GitHub

Fix flash attention (#84)



* ignore self attention mask for causal type
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* further relax checks to run FA, update docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix pytorch softmax path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minimum ampere requirement for fa
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7d6c1d02
@@ -792,6 +792,7 @@ def test_export_core_attention(
if attn_mask_type is None:
attn_mask_type = 'causal'
inp = (query_layer, key_layer, value_layer)
model = te.transformer.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
@@ -16,6 +16,15 @@ THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
_default_causal_mask = {}
def _get_default_causal_mask(sq: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq not in _default_causal_mask:
_default_causal_mask[sq] = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
return _default_causal_mask[sq]
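For reference, here is a minimal sketch (not part of the diff) of what the cached helper produces; it assumes a CUDA device is available, matching the device="cuda" above, and uses sq=4:

import torch

sq = 4
mask = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
# mask[i, j] is True exactly where key position j lies in the future of query
# position i, i.e. the entries a causal softmax must suppress:
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [False, False, False,  True],
#  [False, False, False, False]]
# Keying the module-level dict by sq means each distinct sequence length builds
# this tensor once, and later forward passes reuse the cached copy.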
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
@@ -274,6 +283,10 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.scale is not None:
inp = inp * self.scale
if self.attn_mask_type == "causal":
mask = _get_default_causal_mask(inp.size()[2])
mask_output = self.mask_func(inp, mask) if mask is not None else inp
probs = torch.nn.Softmax(dim=-1)(mask_output)
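Because the causal branch now builds its own mask, the user-supplied attention_mask is no longer needed on this path. A hedged sketch of the masked-softmax step follows; _mask_scores is a hypothetical stand-in for self.mask_func, assumed (as is conventional) to fill masked positions with a large negative value:

import torch

def _mask_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for self.mask_func: push future positions toward
    # -inf so they receive near-zero probability after the softmax.
    return scores.masked_fill(mask, -10000.0)

sq = 4
scores = torch.randn(1, 1, sq, sq)  # [batch, heads, sq, sk] attention scores
mask = torch.triu(torch.ones(sq, sq), diagonal=1).bool()
probs = torch.nn.Softmax(dim=-1)(_mask_scores(scores, mask))
# Probability mass now falls almost entirely on current and past positions.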
@@ -27,6 +27,7 @@ from transformer_engine.pytorch.utils import (
split_tensor_along_dim,
cast_if_needed,
get_default_init_method,
get_device_compute_capability,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
@@ -220,9 +221,6 @@ class FlashAttention(torch.nn.Module):
assert (
attn_mask_type == "causal"
), 'FlashAttention currently only supports causal attention mask.'
assert (
attention_softmax_in_fp32
), 'FlashAttention currently only supports softmax compute in fp32.'
self.attn_causal_mask = attn_mask_type == "causal"
self.norm_factor = norm_factor
@@ -230,6 +228,7 @@
self.attention_dropout = attention_dropout
self.layer_number = layer_number
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
def forward(
self,
@@ -287,6 +286,11 @@ class DotProductAttention(torch.nn.Module):
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
.. warning::
For the default attention mechanism, this module executes a non-deterministic version of
@@ -303,15 +307,6 @@
number of key-value channels.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
apply_query_key_layer_scaling: bool, default = `False`
apply query-key layer scaling during BMM1
by a factor of `layer_number`
attention_softmax_in_fp32: bool, default = `True`
if set to `False`, softmax is executed in
the dtype of activation tensors.
attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
@@ -371,9 +366,8 @@ class DotProductAttention(torch.nn.Module):
self.use_flash_attention = (
int(os.getenv("NVTE_FLASH_ATTN", "1"))
and attention_softmax_in_fp32
and attn_mask_type == "causal"
and not apply_query_key_layer_scaling
and get_device_compute_capability() >= 8.0
)
attn_kwargs = {
@@ -422,6 +416,11 @@ class DotProductAttention(torch.nn.Module):
"""
Dot Product Attention Layer.
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
.. note::
Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
@@ -448,8 +447,7 @@ class DotProductAttention(torch.nn.Module):
"""
use_flash_attention = self.use_flash_attention
if (attention_mask is not None
or query_layer.dtype not in [torch.bfloat16, torch.float16]
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
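With the relaxed checks above, FlashAttention is chosen only when both the construction-time and runtime conditions hold. A rough, illustrative summary of the combined gating (the helper name _flash_attention_applicable is hypothetical; the compute-capability value comes from the utility added at the bottom of this diff):

import os
import torch

def _flash_attention_applicable(attn_mask_type: str,
                                apply_query_key_layer_scaling: bool,
                                compute_capability: float,
                                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> bool:
    # Construction-time gate from this diff: causal mask only, no query-key
    # layer scaling, Ampere (SM 8.0) or newer, and not disabled via env var.
    enabled = (
        bool(int(os.getenv("NVTE_FLASH_ATTN", "1")))
        and attn_mask_type == "causal"
        and not apply_query_key_layer_scaling
        and compute_capability >= 8.0
    )
    # Runtime gate from the forward pass: Q, K and V must all be fp16 or bf16;
    # otherwise the module falls back to the non-flash softmax path.
    supported_dtypes = {torch.float16, torch.bfloat16}
    return enabled and {q.dtype, k.dtype, v.dtype} <= supported_dtypes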
@@ -515,6 +513,7 @@ class MultiHeadAttention(torch.nn.Module):
self.return_layernorm_output = return_layernorm_output
self.params_dtype = params_dtype
self.init_method = init_method
self.attn_mask_type = attn_mask_type
if not fuse_qkv_params:
qkv_weight_interleaved = False
@@ -658,7 +657,7 @@ class MultiHeadAttention(torch.nn.Module):
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
if attention_mask is not None:
if self.attn_mask_type != "causal" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
@@ -836,6 +835,11 @@ class TransformerLayer(torch.nn.Module):
TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need".
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`self_attn_mask_type` is set to `"causal"`.
Parameters
----------
hidden_size : int
@@ -983,6 +987,7 @@ class TransformerLayer(torch.nn.Module):
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm
)
self.self_attn_mask_type = self_attn_mask_type
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
@@ -1129,6 +1134,11 @@ class TransformerLayer(torch.nn.Module):
"""
Transformer Layer: attention block and a feedforward network (MLP)
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
is set to `"causal"`.
Parameters
----------
hidden_states : torch.Tensor
@@ -1163,7 +1173,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states = hidden_states.contiguous()
if attention_mask is not None:
if self.self_attn_mask_type != "causal" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
@@ -8,6 +8,13 @@ from typing import Any, Callable, Optional, Tuple
import torch
def get_device_compute_capability() -> float:
"""Returns the cuda compute capability of current GPU"""
major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
minor = torch.cuda.get_device_properties(torch.cuda.current_device()).minor
return major + minor / 10
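As a usage note, the helper encodes major.minor as a float, so the Ampere check used by DotProductAttention above reduces to a single comparison (illustrative only; assumes a visible CUDA device):

import torch

props = torch.cuda.get_device_properties(torch.cuda.current_device())
compute_capability = props.major + props.minor / 10  # e.g. 7.0 on V100, 8.0 on A100
print(compute_capability >= 8.0)  # True only on Ampere or newer, as required for FlashAttention here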
def attention_mask_func(
attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor: