Unverified Commit 37a12c4e authored by Kirthi Shankar Sivamani, committed by GitHub

Fix flash attention (#84)



* ignore self attention mask for causal type
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* further relax checks to run FA, update docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix pytorch softmax path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minimum ampere requirement for fa
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7d6c1d02
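Taken together, the hunks below relax the conditions under which the FlashAttention path is selected. As a minimal sketch (not the module's actual code, and `can_use_flash_attention` is a hypothetical helper name), the decision after this change amounts to the checks visible in the diff: the NVTE_FLASH_ATTN environment toggle, a causal mask type, fp16/bf16 inputs, and an Ampere-or-newer GPU (compute capability >= 8.0).

import os
import torch

def can_use_flash_attention(attn_mask_type, *tensors):
    """Hypothetical helper collecting the relaxed checks from this PR."""
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    compute_capability = props.major + props.minor / 10  # e.g. 8.0 on A100
    return (
        bool(int(os.getenv("NVTE_FLASH_ATTN", "1")))  # user toggle, on by default
        and attn_mask_type == "causal"                # only causal masking is supported
        and compute_capability >= 8.0                 # minimum Ampere requirement
        and all(t.dtype in (torch.bfloat16, torch.float16) for t in tensors)
    )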
@@ -792,6 +792,7 @@ def test_export_core_attention(
     if attn_mask_type is None:
         attn_mask_type = 'causal'
+    inp = (query_layer, key_layer, value_layer)
     model = te.transformer.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
......
@@ -16,6 +16,15 @@ THREADS_PER_WARP = 32
 THREADS_PER_BLOCK = 128
 
+_default_causal_mask = {}
+
+
+def _get_default_causal_mask(sq: int) -> torch.Tensor:
+    """Return the causal upper triangular mask for softmax input"""
+    if sq not in _default_causal_mask:
+        _default_causal_mask[sq] = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
+    return _default_causal_mask[sq]
+
 
 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
@@ -274,6 +283,10 @@ class FusedScaleMaskSoftmax(nn.Module):
         if self.scale is not None:
             inp = inp * self.scale
 
+        if self.attn_mask_type == "causal":
+            mask = _get_default_causal_mask(inp.size()[2])
+
         mask_output = self.mask_func(inp, mask) if mask is not None else inp
         probs = torch.nn.Softmax(dim=-1)(mask_output)
......
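The cached causal mask added above can be exercised on its own. The snippet below is only an illustration of the caching behaviour; it assumes a CUDA device (matching the device="cuda" in the helper) and that the helper lives in transformer_engine.pytorch.softmax, which the hunk does not state explicitly.

import torch
from transformer_engine.pytorch.softmax import _get_default_causal_mask  # module path assumed

mask = _get_default_causal_mask(4)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]], device='cuda:0')
assert _get_default_causal_mask(4) is mask  # second call reuses the cached tensor for this sequence length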
@@ -27,6 +27,7 @@ from transformer_engine.pytorch.utils import (
     split_tensor_along_dim,
     cast_if_needed,
     get_default_init_method,
+    get_device_compute_capability,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,
@@ -220,9 +221,6 @@ class FlashAttention(torch.nn.Module):
         assert (
             attn_mask_type == "causal"
         ), 'FlashAttention currently only supports causal attention mask.'
-        assert (
-            attention_softmax_in_fp32
-        ), 'FlashAttention currently only supports softmax compute in fp32.'
 
         self.attn_causal_mask = attn_mask_type == "causal"
         self.norm_factor = norm_factor
@@ -230,6 +228,7 @@ class FlashAttention(torch.nn.Module):
         self.attention_dropout = attention_dropout
         self.layer_number = layer_number
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
 
     def forward(
         self,
@@ -287,6 +286,11 @@ class DotProductAttention(torch.nn.Module):
     representation subspaces as described in the paper:
     `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
 
+    .. note::
+
+        Argument :attr:`attention_mask` will be ignored in the `forward` call when
+        :attr:`attn_mask_type` is set to `"causal"`.
+
     .. warning::
 
         For the default attention mechanism, this module executes a non-deterministic version of
@@ -303,15 +307,6 @@ class DotProductAttention(torch.nn.Module):
                  number of key-value channels.
     attention_dropout: float, default = 0.0
                       dropout probability for the dropout op during multi-head attention.
-    layer_number: int, default = `None`
-                 layer number of the current `DotProductAttention` when multiple such modules
-                 are concatenated, for instance in consecutive transformer blocks.
-    apply_query_key_layer_scaling: bool, default = `False`
-                                  apply query-key layer scaling during BMM1
-                                  by a factor of `layer_number`
-    attention_softmax_in_fp32: bool, default = `True`
-                              if set to `False`, softmax is executed in
-                              the dtype of activation tensors.
     attn_mask_type: {'causal', 'padding'}, default = `causal`
                    type of attention mask passed into softmax operation.
@@ -371,9 +366,8 @@ class DotProductAttention(torch.nn.Module):
         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
-            and attention_softmax_in_fp32
             and attn_mask_type == "causal"
-            and not apply_query_key_layer_scaling
+            and get_device_compute_capability() >= 8.0
         )
 
         attn_kwargs = {
@@ -422,6 +416,11 @@ class DotProductAttention(torch.nn.Module):
         """
         Dot Product Attention Layer.
 
+        .. note::
+
+            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
+            is set to `"causal"`.
+
         .. note::
 
             Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
@@ -448,8 +447,7 @@ class DotProductAttention(torch.nn.Module):
         """
         use_flash_attention = self.use_flash_attention
-        if (attention_mask is not None
-            or query_layer.dtype not in [torch.bfloat16, torch.float16]
+        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
             or key_layer.dtype not in [torch.bfloat16, torch.float16]
             or value_layer.dtype not in [torch.bfloat16, torch.float16]
         ):
@@ -515,6 +513,7 @@ class MultiHeadAttention(torch.nn.Module):
         self.return_layernorm_output = return_layernorm_output
         self.params_dtype = params_dtype
         self.init_method = init_method
+        self.attn_mask_type = attn_mask_type
 
         if not fuse_qkv_params:
             qkv_weight_interleaved = False
@@ -658,7 +657,7 @@ class MultiHeadAttention(torch.nn.Module):
         """MultiHeadAttention FWD"""
         # hidden_states: [sq, b, h]
 
-        if attention_mask is not None:
+        if self.attn_mask_type != "causal" and attention_mask is not None:
             assert (
                 attention_mask.dtype == torch.bool
             ), "Attention mask must be a boolean tensor"
@@ -836,6 +835,11 @@ class TransformerLayer(torch.nn.Module):
     TransformerLayer is made up of an attention block and a feedforward network (MLP).
     This standard layer is based on the paper "Attention Is All You Need".
 
+    .. note::
+
+        Argument :attr:`attention_mask` will be ignored in the `forward` call when
+        :attr:`self_attn_mask_type` is set to `"causal"`.
+
     Parameters
     ----------
     hidden_size : int
@@ -983,6 +987,7 @@ class TransformerLayer(torch.nn.Module):
         self.apply_residual_connection_post_layernorm = (
             apply_residual_connection_post_layernorm
         )
+        self.self_attn_mask_type = self_attn_mask_type
 
         assert (
             self_attn_mask_type in AttnMaskTypes
         ), f"self_attn_mask_type {self_attn_mask_type} not supported"
@@ -1129,6 +1134,11 @@ class TransformerLayer(torch.nn.Module):
         """
         Transformer Layer: attention block and a feedforward network (MLP)
 
+        .. note::
+
+            Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
+            is set to `"causal"`.
+
         Parameters
         ----------
         hidden_states : torch.Tensor
@@ -1163,7 +1173,7 @@ class TransformerLayer(torch.nn.Module):
             hidden_states = hidden_states.contiguous()
 
-        if attention_mask is not None:
+        if self.self_attn_mask_type != "causal" and attention_mask is not None:
             assert (
                 attention_mask.dtype == torch.bool
             ), "Attention mask must be a boolean tensor"
......
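One practical consequence of the transformer.py changes: with the default attn_mask_type="causal", callers no longer need to construct or pass an attention mask at all, which is what allows the FlashAttention fast path to run. A rough usage sketch follows; the [sq, b, np, hn] tensor layout and the abbreviated constructor/forward arguments are assumptions for illustration, not taken from this diff.

import torch
import transformer_engine.pytorch as te

attn = te.transformer.DotProductAttention(
    num_attention_heads=16, kv_channels=64, attn_mask_type="causal",
).cuda()
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")  # [sq, b, np, hn] assumed
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v, attention_mask=None)  # the mask argument is ignored for "causal"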
@@ -8,6 +8,13 @@ from typing import Any, Callable, Optional, Tuple
 
 import torch
 
 
+def get_device_compute_capability() -> float:
+    """Returns the cuda compute capability of current GPU"""
+    major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
+    minor = torch.cuda.get_device_properties(torch.cuda.current_device()).minor
+    return major + minor / 10
+
+
 def attention_mask_func(
     attention_scores: torch.Tensor, attention_mask: torch.Tensor
 ) -> torch.Tensor:
......
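For reference, the new get_device_compute_capability helper folds the minor revision into a single float, so the >= 8.0 comparison in DotProductAttention corresponds to the Ampere generation and newer. The example values in the comments below are well-known CUDA architecture numbers, not output produced by this code.

import torch

props = torch.cuda.get_device_properties(torch.cuda.current_device())
print(props.major + props.minor / 10)  # same computation as the helper added above
# V100 (sm_70)  -> 7.0 : FlashAttention path stays disabled
# A100 (sm_80)  -> 8.0 : enabled
# Ada / Hopper (sm_89 / sm_90) -> 8.9 / 9.0 : enabled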