Unverified commit aecd5a8f authored by Charlene Yang, committed by GitHub
Browse files

[PyTorch] Fix FP8 logic related to FA2/FA3 (#1141)



* fix FP8 logic when FA3 is not installed
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak to make logic more explicit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* limit FA3 warning to Hopper and NVTE_FLASH_ATTN=1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prefer fused attn for FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 8ddac3df
...@@ -22,6 +22,7 @@ from transformer_engine.pytorch.attention import ( ...@@ -22,6 +22,7 @@ from transformer_engine.pytorch.attention import (
get_attention_backend, get_attention_backend,
_flash_attn_2_plus, _flash_attn_2_plus,
_flash_attn_2_3_plus, _flash_attn_2_3_plus,
_flash_attn_3_plus,
check_set_window_size, check_set_window_size,
AttentionParams, AttentionParams,
_attention_backends, _attention_backends,
...@@ -135,7 +136,6 @@ def _get_attention_backends( ...@@ -135,7 +136,6 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None alibi_slopes_shape = None
...@@ -678,7 +678,6 @@ def _run_dot_product_attention( ...@@ -678,7 +678,6 @@ def _run_dot_product_attention(
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
# Create seqlens # Create seqlens
...@@ -1167,7 +1166,6 @@ def _run_transformer_layer( ...@@ -1167,7 +1166,6 @@ def _run_transformer_layer(
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
# Create input tensor # Create input tensor
...@@ -1352,8 +1350,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ...@@ -1352,8 +1350,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model] config = model_configs_fp8_vs_f16[model]
global _attention_backends if _flash_attn_3_plus and not is_training:
if not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
...@@ -1379,7 +1376,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ...@@ -1379,7 +1376,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1 rtol = 5e-1
rmse_tol = 0.15 rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("forward output"))
if not is_training: if _flash_attn_3_plus and not is_training:
_error( _error(
flash_attn_fwd_fp8, flash_attn_fwd_fp8,
fused_attn_fwd_f16, fused_attn_fwd_f16,
...@@ -1527,8 +1524,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1527,8 +1524,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends if _flash_attn_3_plus and not is_training:
if not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
...@@ -1555,7 +1551,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1555,7 +1551,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.1 rmse_tol = 0.1
bwd_names = ["dq", "dk", "dv"] bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("forward output"))
if not is_training: if _flash_attn_3_plus and not is_training:
_error( _error(
flash_attn_fwd_fp8, flash_attn_fwd_fp8,
fused_attn_fwd_f16, fused_attn_fwd_f16,
...@@ -1778,7 +1774,6 @@ def _run_custom_mha_fp8(dtype, config, backend): ...@@ -1778,7 +1774,6 @@ def _run_custom_mha_fp8(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint( inp = 0.0001 * torch.randint(
...@@ -1833,7 +1828,6 @@ def _run_ref_mha_f16(dtype, config, backend): ...@@ -1833,7 +1828,6 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda") inp = torch.load("qkv.pt").to(device="cuda")
......
...@@ -74,6 +74,9 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo ...@@ -74,6 +74,9 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.graph import is_graph_capturing
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6") _flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.6.3") _flash_attn_max_version = PkgVersion("2.6.3")
...@@ -89,13 +92,14 @@ try: ...@@ -89,13 +92,14 @@ try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
except PackageNotFoundError: except PackageNotFoundError:
warnings.warn( if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
"To use flash-attn v3, please use the following commands to install: \n" warnings.warn(
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" "To use flash-attn v3, please use the following commands to install: \n"
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n""" """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" """(3) mkdir -p $python_path/flashattn_hopper \n"""
) """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
)
else: else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import ( from flashattn_hopper.flash_attn_interface import (
...@@ -137,10 +141,6 @@ _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") ...@@ -137,10 +141,6 @@ _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler() _stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter) _stream_handler.setFormatter(_formatter)
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
_attention_backends = { _attention_backends = {
"attention_params": None, "attention_params": None,
"use_flash_attention": None, "use_flash_attention": None,
...@@ -382,8 +382,13 @@ def get_attention_backend( ...@@ -382,8 +382,13 @@ def get_attention_backend(
# Filter: Execution type # Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa: if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention and is_training: if use_flash_attention and not _use_flash_attn_3:
logger.debug("Disabling FlashAttention as it does not support FP8 training") logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
use_flash_attention = False
if use_flash_attention and _use_flash_attn_3 and is_training:
logger.debug(
"Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
)
use_flash_attention = False use_flash_attention = False
if use_unfused_attention: if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
...@@ -826,6 +831,21 @@ def get_attention_backend( ...@@ -826,6 +831,21 @@ def get_attention_backend(
) )
use_flash_attention = False use_flash_attention = False
# Select FusedAttention for FP8
# FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes
# scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and _use_flash_attn_3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention "
"supports more accurate scaling factors in FP8 execution"
)
use_flash_attention = False
# Selected backend # Selected backend
if use_flash_attention: if use_flash_attention:
use_fused_attention = False use_fused_attention = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment