Unverified Commit 7e593c3b authored by Charlene Yang, committed by GitHub

Add num_splits support for FA3 backend (#2380)



* [Common] Deleted unused header (#2324)

Deleted unused header
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] L1_jax_distributed_test suite with individual executions (#2321)

* L1 rework
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* comment out test_multi_process_grouped_gemm for now
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm e5m2 from test norm + MXFP8
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* for branch
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* clean up and tests
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* change tests
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [PyTorch debug] Fixes to debug test failures (#2268)

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Fix bug with pre-scale bias (#2300)

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Try to use pre-downloaded dataset artifacts first (#2345)

* Try to use pre-downloaded dataset artifacts first
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Set HF_HUB_OFFLINE to disable any network calls to HF when the
pre-downloaded dataset is available
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Fix out of bounds access in the FP4 dequantize kernel (#2346)

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Make FP8 weights compatible with older MCore version (#2342)

* Make cast_master_weights_to_fp8 compatible with older MCore version
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Rename keep_columnwise to manual_post_all_gather_processing & Optimize unit test
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove redundant _test_mini_optimizer()
Signed-off-by: kunlunl <kunlunl@nvidia.com>

---------

Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#2348)

* Add test to check jaxpr that amax is reused for nvfp4 recipe
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move test to test_helper.py and rename file
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Fix sharding of segment position to match id in ring attention. (#2349)

Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Disable cuDNN attention for known IMA and NaNs (#2344)

* Fix cuDNN backend selection for more cases. Add CG as an option as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuDNN checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add more checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuDNN version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix error message
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add check for window size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Default to fused attention in JAX DPA (#2363)

* Default to fused attention in JAX DPA
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Consolidate documentation for DPA in JAX
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

* Correctly update the documentation for defaults in JAX DPA
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Update cudnn frontend to v1.16.0 (#2362)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287)

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* deprecated compile-time warning + \warning -> \deprecated
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Move Triton to common (#2359)

* move triton to common and change paths
Signed-off-by: tdophung <tdophung@nvidia.com>

* Formatting
Signed-off-by: tdophung <tdophung@nvidia.com>

---------

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Fused layers argument default values changed (#2347)

* Change default activations in MLP and TransformerLayer, dropout rate after FC1 to 0, and return_layernorm_output to False
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fix the failing tests by hard-coding arguments to the previous values instead of relying on the newer defaults
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* remove comment from gpt
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor changes for num_splits logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace None with 1 as default
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dtype in pack/unpack when FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add fused_attn_supported constraint for some tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FA3 installation commands
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FA3 installation commands in DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* separate fused fp8 and f16 flags in tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* initialize fused_attn_supported_f16
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA installation in L3 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: root <root@gpu-h100-0496.cm.cluster>
Co-authored-by: Peter Dykas <wdykas@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: wdykas <73254672+wdykas@users.noreply.github.com>
parent 1df4a69f
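
For orientation before the diff: the change threads a num_splits argument from DotProductAttention.forward through AttentionParams and backend selection down to the FlashAttention-3 kernel call. A minimal usage sketch (the shapes, dtypes, and constructor arguments below are illustrative assumptions, not taken from this diff):

import torch
from transformer_engine.pytorch import DotProductAttention

attn = DotProductAttention(num_attention_heads=24, kv_channels=128)

# Default qkv_format is "sbhd": (sequence, batch, heads, head_dim).
q = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")

# num_splits != 1 is honored only by FlashAttention-3; the backend filter
# added in this PR disables the other backends in that case.
out = attn(q, k, v, num_splits=4)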
@@ -30,13 +30,13 @@ do
     # Build Flash Attention
     if [ "${fa_version}" \< "3.0.0" ]
     then
-        pip3 install flash-attn==${fa_version}
+        pip3 install flash-attn==${fa_version} --no-build-isolation
     else
         git clone https://github.com/Dao-AILab/flash-attention.git
-        cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+        cd flash-attention/hopper && python setup.py install
         python_path=`python -c "import site; print(site.getsitepackages()[0])"`
         mkdir -p $python_path/flash_attn_3
-        wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
+        cp flash_attn_interface.py $python_path/flash_attn_3/
         cd ../../
     fi
...
@@ -117,7 +117,14 @@ model_configs_base = {
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
 def test_dot_product_attention(
-    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
+    dtype,
+    model_configs,
+    model,
+    ckpt_attn,
+    workspace_opt,
+    qkv_layout,
+    swa,
+    pad_between_seqs,
 ):
     """Test DotProductAttention module"""
@@ -308,6 +315,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
     test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


+model_configs_num_splits = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
+    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
+@pytest.mark.parametrize("model", model_configs_num_splits.keys())
+def test_dpa_num_splits(dtype, model_configs, model):
+    """Test DotProductAttention with FlashAttention-3 num_splits enabled"""
+    test_dot_product_attention(
+        dtype,
+        model_configs,
+        model,
+        False,
+        True,
+        None,
+        False,
+        False,
+    )
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -1152,6 +1184,8 @@ def _run_dot_product_attention(
             core_attention_bias=bias,
             alibi_slopes=alibi_slopes,
             fast_zero_fill=True,
+            # Only pass num_splits when exercising the FlashAttention path
+            num_splits=config.num_splits if backend == "FlashAttention" else 1,
         )
     max_logit = None
     if config.return_max_logit:
@@ -1786,9 +1820,10 @@ def test_mha_fp8_vs_f16(
         fp8_meta=fp8_meta,
         is_training=is_training,
     )
-    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    if flash_attn_supported + fused_attn_supported < 1:
+    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
+    if flash_attn_supported + fused_attn_supported_fp8 < 1:
         pytest.skip("No FP8 attention backend available.")
+    fused_attn_supported_f16 = False
     if not fp8_dpa_bwd:
         available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
@@ -1796,8 +1831,8 @@ def test_mha_fp8_vs_f16(
             qkv_layout=qkv_format.replace("hd", "h3d"),
             is_training=is_training,
         )
-        _, fused_attn_supported, _ = available_backends
-        if not fused_attn_supported:
+        _, fused_attn_supported_f16, _ = available_backends
+        if not fused_attn_supported_f16:
             pytest.skip("No attention backend available.")
     if flash_attn_supported:
@@ -1809,6 +1844,7 @@ def test_mha_fp8_vs_f16(
             dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
         )
-    if fused_attn_supported:
+    if fused_attn_supported_fp8:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1817,6 +1853,10 @@ def test_mha_fp8_vs_f16(
             dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
         )
+    if fused_attn_supported_f16:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
         fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
             dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
@@ -1825,7 +1865,7 @@ def test_mha_fp8_vs_f16(
     atol = 5e-1
     rtol = 5e-1
     rmse_tol = 0.15
-    if flash_attn_supported:
+    if flash_attn_supported and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
@@ -1838,6 +1878,7 @@ def test_mha_fp8_vs_f16(
             rmse_tol,
             True,
         )
+    if fused_attn_supported_fp8 and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
...
@@ -8,6 +8,7 @@ import logging
 import os
 from contextlib import contextmanager
 from typing import Optional, Tuple, Dict, Any, List
+from packaging.version import Version as PkgVersion

 import torch
@@ -210,6 +211,7 @@ class ModelConfig:
         max_ctx_len: int = None,
         num_layers: int = 1,
         eps: float = 1e-5,
+        num_splits=1,
     ):
         self.batch_size = batch_size
         self.max_seqlen_q = max_seqlen_q
@@ -239,6 +241,7 @@ class ModelConfig:
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
         self.eps = eps
+        self.num_splits = num_splits

 @contextmanager
@@ -321,6 +324,9 @@ def get_available_attention_backends(
         inference_params=inference_params,
         softmax_type=config.softmax_type,
         return_max_logit=config.return_max_logit,
+        # allow all backends to pass so they can be used for testing;
+        # check for FA3 availability later
+        num_splits=1,
     )
     (
         use_flash_attention,
@@ -330,6 +336,10 @@ def get_available_attention_backends(
         use_unfused_attention,
         available_backends,
     ) = get_attention_backend(attention_params)
+    # Check if FA3 is an available backend when num_splits != 1
+    if available_backends[0]:
+        if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
+            available_backends[0] = False
     # Set attention.py _attention_backends var using return value
     # from get_attention_backend()
     _attention_backends["use_flash_attention"] = use_flash_attention
...
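
The PkgVersion guard above keeps the flash backend only when the installed FlashAttention reports a version above 3.0.0b, i.e. an FA3 build. A quick check of how packaging's version ordering handles the beta suffix:

from packaging.version import Version as PkgVersion

# "3.0.0b" parses as the bare beta 3.0.0b0, so any later beta or final
# FA3 release compares greater, while every FA2 release compares smaller.
print(PkgVersion("3.0.0b1") > PkgVersion("3.0.0b"))  # True
print(PkgVersion("2.7.4") > PkgVersion("3.0.0b"))    # False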
@@ -681,6 +681,7 @@ class FlashAttention(torch.nn.Module):
         inference_params: Optional[InferenceParams] = None,
         flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
         fp8_output: bool = False,
+        num_splits: Optional[int] = 1,
     ) -> torch.Tensor:
         """flash-attn fprop"""
@@ -957,6 +958,7 @@ class FlashAttention(torch.nn.Module):
                 else:
                     fa_3_optional_forward_kwargs = {}
                 fa_3_optional_forward_kwargs["window_size"] = window_size
+                fa_3_optional_forward_kwargs["num_splits"] = num_splits
                 if inference_params is None:
                     fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                 else:
...
@@ -799,6 +799,7 @@ class DotProductAttention(TransformerEngineBaseModule):
         inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        num_splits: Optional[int] = 1,
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -973,6 +974,10 @@ class DotProductAttention(TransformerEngineBaseModule):
             If true, there are padding tokens between individual sequences in a packed batch.
         fp8_output: Optional[bool], default = `False`
             Whether to enforce output to be in FP8 or not.
+        num_splits: Optional[int], default = 1
+            Optional split control for FlashAttention-3 only. When set, this value is forwarded
+            to the FA3 backend to control internal kernel splitting behavior for
+            non-context-parallel cases. It is ignored for other backends and when context
+            parallelism is enabled.
         """
         with self.prepare_forward(
@@ -1315,6 +1320,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             softmax_type=self.softmax_type,
             return_max_logit=self.return_max_logit,
             cuda_graph=is_graph_capturing(),
+            num_splits=num_splits,
         )
         global _attention_backends
         if is_in_onnx_export_mode():
@@ -1413,6 +1419,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                 inference_params=inference_params,
                 flash_attention_backend=flash_attention_backend,
                 fp8_output=fp8_output,
+                num_splits=num_splits,
             )
         if use_fused_attention:
...
@@ -135,7 +135,7 @@ class FlashAttentionUtils:
     # Please follow these instructions to install FA3
     v3_installation_steps = """\
     (1) git clone https://github.com/Dao-AILab/flash-attention.git
-    (2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
+    (2) cd flash-attention/hopper && python setup.py install
     (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
     (4) mkdir -p $python_path/flash_attn_3
     (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
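
After following these steps, the install can be sanity-checked with a quick probe (a sketch; the module name flash_attn_3 matches the directory created in step (4)):

import importlib.util

if importlib.util.find_spec("flash_attn_3") is not None:
    from flash_attn_3 import flash_attn_interface
    print("FA3 interface found at:", flash_attn_interface.__file__)
else:
    print("FA3 is not installed; follow v3_installation_steps above.")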
@@ -233,6 +233,8 @@ class AttentionParams:
         Whether to output max_logit.
     cuda_graph: bool, default = `False`
         Whether support for cuda graph capture is needed or not.
+    num_splits: int, default = 1
+        The number of kernels to split the attention computation into.
     """

     qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
@@ -263,6 +265,7 @@ class AttentionParams:
     softmax_type: str = "vanilla"
     return_max_logit: bool = False
     cuda_graph: bool = False
+    num_splits: int = 1

     def __eq__(self, other):
         """
@@ -338,6 +341,7 @@ def get_attention_backend(
     softmax_type = attention_params.softmax_type
     return_max_logit = attention_params.return_max_logit
     cuda_graph = attention_params.cuda_graph
+    num_splits = attention_params.num_splits

     # Run config
     logger = logging.getLogger("DotProductAttention")
@@ -511,6 +515,18 @@ def get_attention_backend(
         use_flash_attention = False
         use_fused_attention = False

+    # Filter: num_splits
+    if num_splits != 1:
+        if use_flash_attention_2 and FlashAttentionUtils.is_installed:
+            logger.debug("Disabling FlashAttention 2 for num_splits")
+            use_flash_attention_2 = False
+        if use_fused_attention:
+            logger.debug("Disabling FusedAttention for num_splits")
+            use_fused_attention = False
+        if use_unfused_attention:
+            logger.debug("Disabling UnfusedDotProductAttention for num_splits")
+            use_unfused_attention = False
+
     # Filter: Return max_logit
     if return_max_logit:
         if use_flash_attention:
@@ -1566,8 +1582,9 @@ def _pack_tensor(
     """
     Packs the given tensor using the `indices`.
     """
+    dtype = tensor.dtype if not isinstance(tensor, Float8Tensor) else torch.uint8
     padding_indice = torch.zeros(
-        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
+        1, tensor.shape[1], tensor.shape[2], dtype=dtype, device=tensor.device
     )
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
     if isinstance(tensor, Float8Tensor):
@@ -1622,8 +1639,9 @@ def _unpack_tensor(
     Inverse of `_pack_tensor`.
     """
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
+    dtype = tensor.dtype if not isinstance(tensor, Float8Tensor) else torch.uint8
     unpacked = torch.zeros(
-        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
+        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=dtype, device=tensor.device
     )
     if isinstance(tensor, Float8Tensor):
         unpacked.scatter_(0, indices, tensor._data)
...
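
The two hunks above are the "fix dtype in pack/unpack when FP8" change from the commit list: for a Float8Tensor, gather/scatter operate on the raw byte payload (tensor._data), so the zero-filled buffers must be allocated as torch.uint8 rather than the tensor's nominal dtype. A minimal stand-alone illustration (a sketch; a plain uint8 buffer stands in for Float8Tensor._data):

import torch

# Stand-in for Float8Tensor._data: FP8 payloads are stored as raw uint8 bytes.
fp8_payload = torch.randint(0, 256, (4, 2, 3), dtype=torch.uint8)

# The zero padding row must match the payload's storage dtype; allocating it
# with the tensor's nominal dtype (e.g. bfloat16) makes cat/gather mismatch.
padding = torch.zeros(1, 2, 3, dtype=fp8_payload.dtype)

buf = torch.cat([fp8_payload, padding], dim=0)
indices = torch.tensor([2, 0, 1]).view(-1, 1, 1).repeat(1, 2, 3)
packed = torch.gather(buf, 0, indices)  # byte-level gather, dtypes agree
print(packed.shape)  # torch.Size([3, 2, 3])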