Unverified commit aecd5a8f authored by Charlene Yang, committed by GitHub

[PyTorch] Fix FP8 logic related to FA2/FA3 (#1141)



* fix FP8 logic when FA3 is not installed
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak to make logic more explicit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* limit FA3 warning to Hopper and NVTE_FLASH_ATTN=1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prefer fused attn for FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 8ddac3df
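
As context for the test changes below: the tests pin attention to a single backend by setting the NVTE_* environment variables and then invalidating the cached backend selection. A minimal sketch of that pattern follows; the `force_backend` helper is hypothetical and not part of the test suite, but the environment variables and the `_attention_backends` flag are the ones exercised in the diff.

```python
import os

from transformer_engine.pytorch.attention import _attention_backends


def force_backend(backend: str) -> None:
    """Hypothetical helper mirroring the pattern used by the tests below.

    `backend` is one of "FlashAttention", "FusedAttention", or
    "UnfusedDotProductAttention".
    """
    os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FlashAttention" else "0"
    os.environ["NVTE_FUSED_ATTN"] = "1" if backend == "FusedAttention" else "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "1" if backend == "UnfusedDotProductAttention" else "0"
    # Backend selection is cached, so flag it for re-evaluation on the next forward pass.
    _attention_backends["backend_selection_requires_update"] = True
```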
@@ -22,6 +22,7 @@ from transformer_engine.pytorch.attention import (
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
_flash_attn_3_plus,
check_set_window_size,
AttentionParams,
_attention_backends,
@@ -135,7 +136,6 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
@@ -678,7 +678,6 @@ def _run_dot_product_attention(
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
@@ -1167,7 +1166,6 @@ def _run_transformer_layer(
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
@@ -1352,8 +1350,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
global _attention_backends
if not is_training:
if _flash_attn_3_plus and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
@@ -1379,7 +1376,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if not is_training:
if _flash_attn_3_plus and not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
@@ -1527,8 +1524,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
if not is_training:
if _flash_attn_3_plus and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
@@ -1555,7 +1551,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.1
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if not is_training:
if _flash_attn_3_plus and not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
@@ -1778,7 +1774,6 @@ def _run_custom_mha_fp8(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
@@ -1833,7 +1828,6 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
@@ -74,6 +74,9 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.6.3")
@@ -89,6 +92,7 @@ try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
@@ -137,10 +141,6 @@ _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
@@ -382,8 +382,13 @@ def get_attention_backend(
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention and is_training:
logger.debug("Disabling FlashAttention as it does not support FP8 training")
if use_flash_attention and not _use_flash_attn_3:
logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
use_flash_attention = False
if use_flash_attention and _use_flash_attn_3 and is_training:
logger.debug(
"Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
)
use_flash_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
@@ -826,6 +831,21 @@ def get_attention_backend(
)
use_flash_attention = False
# Select FusedAttention for FP8
# FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes
# scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and _use_flash_attn_3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention "
"supports more accurate scaling factors in FP8 execution"
)
use_flash_attention = False
# Selected backend
if use_flash_attention:
use_fused_attention = False
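
Taken together, the hunks above implement the following FP8 preference order: FlashAttention 2 is never used for FP8, FlashAttention 3 is allowed for FP8 inference only, FusedAttention is preferred over FA3 because it applies the scaling factors from `fp8_meta` (FA3 uses defaults of 1.0), and UnfusedDotProductAttention has no FP8 support. A standalone sketch of that order, using a hypothetical function name and arguments rather than `get_attention_backend`'s real signature:

```python
import os
from typing import Optional


def pick_fp8_backend(
    fa3_installed: bool,
    is_training: bool,
    fused_backend_is_fp8: bool = True,
) -> Optional[str]:
    """Hypothetical sketch of the FP8 backend preference shown in the diff above."""
    use_flash = bool(int(os.getenv("NVTE_FLASH_ATTN", "1")))
    use_fused = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

    # FlashAttention 2 does not support FP8 at all.
    if use_flash and not fa3_installed:
        use_flash = False
    # FlashAttention 3 supports FP8 inference but not FP8 training.
    if use_flash and fa3_installed and is_training:
        use_flash = False
    # Prefer FusedAttention over FA3 for FP8: FusedAttention takes scaling
    # factors from fp8_meta, while FA3 uses default scaling factors of 1.0.
    if use_flash and use_fused and fused_backend_is_fp8:
        use_flash = False

    if use_flash:
        return "FlashAttention"
    if use_fused:
        return "FusedAttention"
    # UnfusedDotProductAttention is disabled for FP8.
    return None
```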