Unverified Commit e762592e authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Miscellaneous fixes for FA3 attention (#1174)



* add qkv descales to FA3
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix sbhd shapes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* force the same dtype when comparing FA3 and cuDNN FP8
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "force the same dtype when comparing FA3 and cuDNN FP8"

This reverts commit 19e7f877026a19a32d2f02c6c9de20df4ae2e064.
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force the same dtype when comparing FA3 and cuDNN FP8
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add try/except for FA3 when custom qkv descales are not supported
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace FA3 installation warning with a debug logging message
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove unused imports
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* avoid varlen_func for FP8 and improve messaging
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add SWA support for FA3
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change preference reason for FP8 logic
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c3b3cd21
...@@ -13,7 +13,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py ...@@ -13,7 +13,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
......
...@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): ...@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try: try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol) torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
......
...@@ -85,6 +85,16 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode ...@@ -85,6 +85,16 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.graph import is_graph_capturing
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
...@@ -100,29 +110,31 @@ _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") ...@@ -100,29 +110,31 @@ _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_3_plus = False _flash_attn_3_plus = False
_use_flash_attn_3 = False _use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
try: try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
_flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
except PackageNotFoundError: except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn( fa3_logger = logging.getLogger()
"To use flash-attn v3, please use the following commands to install: \n" fa3_logger.setLevel(_log_level)
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" if not fa3_logger.hasHandlers():
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" fa3_logger.addHandler(_stream_handler)
"""(3) mkdir -p $python_path/flashattn_hopper \n""" fa3_logger.debug(
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" "To use flash-attn v3, please follow these steps to install the flashattn-hopper "
"package: \n%s",
_flash_attn_3_installation_steps,
) )
else: else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import ( from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3,
) )
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_forward as _flash_attn_forward_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_backward as _flash_attn_backward_v3,
)
_use_flash_attn_3 = True _use_flash_attn_3 = True
...@@ -132,18 +144,6 @@ if _flash_attn_version >= _flash_attn_version_required: ...@@ -132,18 +144,6 @@ if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
_attention_backends = { _attention_backends = {
"attention_params": None, "attention_params": None,
"use_flash_attention": None, "use_flash_attention": None,
...@@ -348,7 +348,7 @@ def get_attention_backend( ...@@ -348,7 +348,7 @@ def get_attention_backend(
use_fused_attention = False use_fused_attention = False
# Filter: Compute capability # Filter: Compute capability
global _flash_attn_3_plus, _use_flash_attn_3 global _use_flash_attn_3
if device_compute_capability < (8, 0): if device_compute_capability < (8, 0):
if use_flash_attention: if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+") logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
...@@ -357,7 +357,7 @@ def get_attention_backend( ...@@ -357,7 +357,7 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention as it requires compute capability sm80+") logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False use_fused_attention = False
if device_compute_capability < (9, 0): if device_compute_capability < (9, 0):
if use_flash_attention and _flash_attn_3_plus: if use_flash_attention and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
_use_flash_attn_3 = False _use_flash_attn_3 = False
...@@ -438,8 +438,7 @@ def get_attention_backend( ...@@ -438,8 +438,7 @@ def get_attention_backend(
use_flash_attention = False use_flash_attention = False
# Filter: Dropout # Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention: if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
if _flash_attn_3_plus and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout") logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False _use_flash_attn_3 = False
...@@ -461,7 +460,7 @@ def get_attention_backend( ...@@ -461,7 +460,7 @@ def get_attention_backend(
) )
use_unfused_attention = False use_unfused_attention = False
if context_parallel and use_flash_attention: if context_parallel and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3: if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for context parallelism") logger.debug("Disabling FlashAttention 3 for context parallelism")
_use_flash_attn_3 = False _use_flash_attn_3 = False
if fp8 and fp8_meta["recipe"].fp8_dpa: if fp8 and fp8_meta["recipe"].fp8_dpa:
...@@ -556,7 +555,7 @@ def get_attention_backend( ...@@ -556,7 +555,7 @@ def get_attention_backend(
use_fused_attention = False use_fused_attention = False
if ( if (
use_flash_attention use_flash_attention
and _flash_attn_3_plus and _use_flash_attn_3
and attn_mask_type in ["causal", "padding_causal"] and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv and max_seqlen_q != max_seqlen_kv
): ):
...@@ -590,6 +589,15 @@ def get_attention_backend( ...@@ -590,6 +589,15 @@ def get_attention_backend(
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
) )
use_flash_attention = False use_flash_attention = False
if (
use_flash_attention
and _use_flash_attn_3
and fp8
and fp8_meta["recipe"].fp8_dpa
and "padding" in attn_mask_type
):
logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
_use_flash_attn_3 = False
# Filter: Sliding window attention # Filter: Sliding window attention
# backend | window_size | diagonal alignment # backend | window_size | diagonal alignment
...@@ -633,15 +641,6 @@ def get_attention_backend( ...@@ -633,15 +641,6 @@ def get_attention_backend(
attn_mask_type, attn_mask_type,
) )
use_fused_attention = False use_fused_attention = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
and _flash_attn_3_plus
):
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
_use_flash_attn_3 = False
if ( if (
use_flash_attention use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0]) and (window_size[0] != -1 or window_size[1] not in [-1, 0])
...@@ -662,11 +661,11 @@ def get_attention_backend( ...@@ -662,11 +661,11 @@ def get_attention_backend(
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi": if use_flash_attention and core_attention_bias_type == "alibi":
if _flash_attn_3_plus and _use_flash_attn_3: if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for ALiBi") logger.debug("Disabling FlashAttention 3 for ALiBi")
_use_flash_attn_3 = False _use_flash_attn_3 = False
if not _flash_attn_2_4_plus: if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention for ALiBi") logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention = False use_flash_attention = False
if use_flash_attention and ( if use_flash_attention and (
...@@ -827,10 +826,6 @@ def get_attention_backend( ...@@ -827,10 +826,6 @@ def get_attention_backend(
"for performance reasons" "for performance reasons"
) )
use_flash_attention = False use_flash_attention = False
# Select FusedAttention for FP8
# FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes
# scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization
if ( if (
use_flash_attention use_flash_attention
and use_fused_attention and use_fused_attention
...@@ -838,8 +833,8 @@ def get_attention_backend( ...@@ -838,8 +833,8 @@ def get_attention_backend(
and _use_flash_attn_3 and _use_flash_attn_3
): ):
logger.debug( logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention " "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"supports more accurate scaling factors in FP8 execution" "in FP8 execution"
) )
use_flash_attention = False use_flash_attention = False
...@@ -4963,6 +4958,10 @@ class FlashAttention(torch.nn.Module): ...@@ -4963,6 +4958,10 @@ class FlashAttention(torch.nn.Module):
self.attention_type = attention_type self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic self.deterministic = deterministic
self.logger = logging.getLogger("FlashAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)
def forward( def forward(
self, self,
...@@ -5033,6 +5032,10 @@ class FlashAttention(torch.nn.Module): ...@@ -5033,6 +5032,10 @@ class FlashAttention(torch.nn.Module):
x.transpose(0, 1) x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data) for x in (query_layer._data, key_layer._data, value_layer._data)
] ]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data)
for x in (query_layer, key_layer, value_layer)
]
if context_parallel: if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [ query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
...@@ -5168,24 +5171,43 @@ class FlashAttention(torch.nn.Module): ...@@ -5168,24 +5171,43 @@ class FlashAttention(torch.nn.Module):
fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv) fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3: if _use_flash_attn_3:
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
if fp8: if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
def convert_to_torch_float8(tensor, dtype):
out = torch.Tensor().to(device=tensor.device, dtype=dtype)
out.set_(
tensor._data.untyped_storage(),
tensor._data.storage_offset(),
tensor._data.shape,
tensor._data.stride(),
)
return out
if fp8_meta["recipe"].fp8_mha: if fp8_meta["recipe"].fp8_mha:
assert all( assert all(
isinstance(x, Float8Tensor) isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer] for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA." ), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
else:
query_layer, key_layer, value_layer = ( query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype) Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
for x in [query_layer, key_layer, value_layer] for x in [query_layer, key_layer, value_layer]
) )
else: fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
query_layer, key_layer, value_layer = ( query_layer, key_layer, value_layer = (
x.to(torch_dtype) for x in [query_layer, key_layer, value_layer] convert_to_torch_float8(x, torch_dtype)
for x in [query_layer, key_layer, value_layer]
) )
try:
output, _ = func( output, _ = func(
query_layer, query_layer,
key_layer, key_layer,
...@@ -5193,8 +5215,18 @@ class FlashAttention(torch.nn.Module): ...@@ -5193,8 +5215,18 @@ class FlashAttention(torch.nn.Module):
*fa_optional_forward_args_thd, *fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type, causal="causal" in attn_mask_type,
deterministic=self.deterministic, **fa_3_optional_forward_kwargs,
) )
except TypeError as e:
if _flash_attn_3_0_0_beta:
e.args = (
e.args[0]
+ ". Please update your FlashAttention 3 (beta) installation as it "
+ "may have added more supported arguments to its API. \n"
+ _flash_attn_3_installation_steps,
) + e.args[1:]
raise
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8( output = cast_to_fp8(
output, output,
...@@ -5228,8 +5260,12 @@ class FlashAttention(torch.nn.Module): ...@@ -5228,8 +5260,12 @@ class FlashAttention(torch.nn.Module):
if qkv_format == "sbhd": if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd) # (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_meta["recipe"].fp8_mha:
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d() output = Float8Tensor.make_like(
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) output,
data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous(),
)
else: else:
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd": elif qkv_format == "bshd":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment