Unverified commit 7e593c3b, authored by Charlene Yang and committed by GitHub

Add num_splits support for FA3 backend (#2380)



* [Common] Deleted unused header (#2324)

Deleted unused header
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] L1_jax_distributed_test suite with individual executions (#2321)

* L1 rework
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* comment out test_multi_process_grouped_gemm for now
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm e5m2 from test norm + MXFP8
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* for branch
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* clean up and tests
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* change tests
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [PyTorch debug] Fixes to debug test failures (#2268)

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix:
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Fix bug with pre scale bias (#2300)

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Try to use pre-downloaded dataset artifacts first (#2345)

* Try to use pre-downloaded dataset artifacts first
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Set HF_HUB_OFFLINE to disable any network calls to HF when the
pre-downloaded dataset is available
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Fix out of bounds access in the FP4 dequantize kernel (#2346)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Make FP8 weights compatible with older MCore version (#2342)

* Make cast_master_weights_to_fp8 compatible with older MCore version
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Rename keep_columnwise to manual_post_all_gather_processing & Optimize unit test
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove redundant _test_mini_optimizer()
Signed-off-by: kunlunl <kunlunl@nvidia.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#2348)

* Add test to check jaxpr that amax is reused for nvfp4 recipe
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move test to test_helper.py and rename file
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Fix sharding of segment position to match id in ring attention. (#2349)
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Disable cuDNN attention for known IMA and NaNs (#2344)

* Fix cuDNN backend selection for more cases. Add CG as an option as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuDNN checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add more checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuDNN version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix error message
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add check for window size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Default to fused attention in JAX DPA (#2363)

* Default to fused attention in JAX DPA
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Consolidate documentation for DPA in JAX
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

* Correctly update the documentation for defaults in JAX DPA
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Update cudnn frontend to v1.16.0 (#2362)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287)

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* deprecated compile-time warning + \warning -> \deprecated
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* Move Triton to common (#2359)

* move triton to common and change paths
Signed-off-by: tdophung <tdophung@nvidia.com>

* Formatting
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [JAX] Fused layers argument default values changed (#2347)

* Changing default activations in MLP, TransformerLayer, dropout rate after FC1 to 0, and return_layernorm_output to False
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fixing the failing tests by hard coding arguments to the previous values instead of relying on newer default values
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* remove comment from gpt
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor changes for num_splits logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace None with 1 as default
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dtype in pack/unpack when FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add fused_attn_supported constraint for some tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FA3 installation commands
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FA3 installation commands in DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* separate fused fp8 and f16 flags in tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* initialize fused_attn_supported_f16
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA installation in L3 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: root <root@gpu-h100-0496.cm.cluster>
Co-authored-by: Peter Dykas <wdykas@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: wdykas <73254672+wdykas@users.noreply.github.com>
parent 1df4a69f
@@ -30,13 +30,13 @@ do
     # Build Flash Attention
     if [ "${fa_version}" \< "3.0.0" ]
     then
-        pip3 install flash-attn==${fa_version}
+        pip3 install flash-attn==${fa_version} --no-build-isolation
     else
         git clone https://github.com/Dao-AILab/flash-attention.git
-        cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+        cd flash-attention/hopper && python setup.py install
         python_path=`python -c "import site; print(site.getsitepackages()[0])"`
         mkdir -p $python_path/flash_attn_3
-        wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
+        cp flash_attn_interface.py $python_path/flash_attn_3/
         cd ../../
     fi
......
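The updated build step installs FA3 straight from the repository's hopper/ directory and copies flash_attn_interface.py into a flash_attn_3 package directory, instead of pinning a commit and fetching the interface file over the network. A quick post-install sanity check might look like the sketch below; only the flash_attn_3.flash_attn_interface module path is taken from the script above, the rest is illustrative.

# Post-install sanity check (sketch): confirm the copied FA3 interface is importable.
import importlib.util

spec = importlib.util.find_spec("flash_attn_3.flash_attn_interface")
if spec is None:
    print("flash_attn_3.flash_attn_interface not found; the FA3 backend will be unavailable")
else:
    print("FA3 interface found at", spec.origin)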
@@ -117,7 +117,14 @@ model_configs_base = {
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
 def test_dot_product_attention(
-    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
+    dtype,
+    model_configs,
+    model,
+    ckpt_attn,
+    workspace_opt,
+    qkv_layout,
+    swa,
+    pad_between_seqs,
 ):
     """Test DotProductAttention module"""
@@ -308,6 +315,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
     test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


+model_configs_num_splits = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
+    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
+@pytest.mark.parametrize("model", model_configs_num_splits.keys())
+def test_dpa_num_splits(dtype, model_configs, model):
+    """Test DotProductAttention with FlashAttention-3 num_splits enabled"""
+    test_dot_product_attention(
+        dtype,
+        model_configs,
+        model,
+        False,
+        True,
+        None,
+        False,
+        False,
+    )
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -1152,6 +1184,8 @@ def _run_dot_product_attention(
         core_attention_bias=bias,
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
+        # Only pass num_splits when exercising the FlashAttention path
+        num_splits=config.num_splits if backend == "FlashAttention" else 1,
     )
     max_logit = None
     if config.return_max_logit:
@@ -1786,9 +1820,10 @@ def test_mha_fp8_vs_f16(
         fp8_meta=fp8_meta,
         is_training=is_training,
     )
-    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    if flash_attn_supported + fused_attn_supported < 1:
+    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
+    if flash_attn_supported + fused_attn_supported_fp8 < 1:
         pytest.skip("No FP8 attention backend available.")
+    fused_attn_supported_f16 = False
     if not fp8_dpa_bwd:
         available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
@@ -1796,8 +1831,8 @@ def test_mha_fp8_vs_f16(
             qkv_layout=qkv_format.replace("hd", "h3d"),
             is_training=is_training,
         )
-        _, fused_attn_supported, _ = available_backends
-        if not fused_attn_supported:
+        _, fused_attn_supported_f16, _ = available_backends
+        if not fused_attn_supported_f16:
             pytest.skip("No attention backend available.")

     if flash_attn_supported:
@@ -1809,6 +1844,7 @@ def test_mha_fp8_vs_f16(
         dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
     )

+    if fused_attn_supported_fp8:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
         _attention_backends["backend_selection_requires_update"] = True
@@ -1817,6 +1853,10 @@ def test_mha_fp8_vs_f16(
         dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
     )

+    if fused_attn_supported_f16:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
         dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
@@ -1825,7 +1865,7 @@ def test_mha_fp8_vs_f16(
     atol = 5e-1
     rtol = 5e-1
     rmse_tol = 0.15
-    if flash_attn_supported:
+    if flash_attn_supported and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
@@ -1838,6 +1878,7 @@ def test_mha_fp8_vs_f16(
             rmse_tol,
             True,
         )
+    if fused_attn_supported_fp8 and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
......
@@ -8,6 +8,7 @@ import logging
 import os
 from contextlib import contextmanager
 from typing import Optional, Tuple, Dict, Any, List
+from packaging.version import Version as PkgVersion

 import torch
@@ -210,6 +211,7 @@ class ModelConfig:
         max_ctx_len: int = None,
         num_layers: int = 1,
         eps: float = 1e-5,
+        num_splits=1,
     ):
         self.batch_size = batch_size
         self.max_seqlen_q = max_seqlen_q
@@ -239,6 +241,7 @@ class ModelConfig:
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
         self.eps = eps
+        self.num_splits = num_splits

     @contextmanager
@@ -321,6 +324,9 @@ def get_available_attention_backends(
         inference_params=inference_params,
         softmax_type=config.softmax_type,
         return_max_logit=config.return_max_logit,
+        # allow all backends to pass so they can be used for testing;
+        # check for FA3 availability later
+        num_splits=1,
     )
     (
         use_flash_attention,
@@ -330,6 +336,10 @@ def get_available_attention_backends(
         use_unfused_attention,
         available_backends,
     ) = get_attention_backend(attention_params)
+    # Check if FA3 is an available backend when num_splits != 1
+    if available_backends[0]:
+        if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
+            available_backends[0] = False
     # Set attention.py _attention_backends var using return value
     # from get_attention_backend()
     _attention_backends["use_flash_attention"] = use_flash_attention
......
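get_available_attention_backends now always queries get_attention_backend with num_splits=1 so that every backend can be probed, and only afterwards applies the FA3 version gate. A standalone sketch of that gate follows, assuming flash_attention_backend is a packaging Version parsed from the installed flash-attn build; the helper name is illustrative, while the comparison against PkgVersion("3.0.0b") is taken from the diff above.

from packaging.version import Version as PkgVersion

def fa3_supports_num_splits(flash_attention_backend: PkgVersion, num_splits: int) -> bool:
    # num_splits != 1 is only honored by a FlashAttention-3 build,
    # i.e. a version newer than 3.0.0b, mirroring the check above.
    return num_splits == 1 or flash_attention_backend > PkgVersion("3.0.0b")

print(fa3_supports_num_splits(PkgVersion("2.7.3"), 4))    # False: an FA2 install cannot serve num_splits
print(fa3_supports_num_splits(PkgVersion("3.0.0b1"), 4))  # True: FA3 beta or newer is fine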
@@ -681,6 +681,7 @@ class FlashAttention(torch.nn.Module):
         inference_params: Optional[InferenceParams] = None,
         flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
         fp8_output: bool = False,
+        num_splits: Optional[int] = 1,
     ) -> torch.Tensor:
         """flash-attn fprop"""
@@ -957,6 +958,7 @@ class FlashAttention(torch.nn.Module):
             else:
                 fa_3_optional_forward_kwargs = {}
                 fa_3_optional_forward_kwargs["window_size"] = window_size
+                fa_3_optional_forward_kwargs["num_splits"] = num_splits
                 if inference_params is None:
                     fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                 else:
......
@@ -799,6 +799,7 @@ class DotProductAttention(TransformerEngineBaseModule):
         inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        num_splits: Optional[int] = 1,
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -973,6 +974,10 @@ class DotProductAttention(TransformerEngineBaseModule):
             If true, there are padding tokens between individual sequences in a packed batch.
         fp8_output: Optional[bool], default = `False`
             Whether to enforce output to be in FP8 or not.
+        num_splits: Optional[int], default = 1
+            Optional split control for FlashAttention-3 only. When set, this value is forwarded
+            to the FA3 backend to control internal kernel splitting behavior for
+            non-context-parallel cases. It is ignored by other backends and when context
+            parallelism is enabled.
         """
@@ -1315,6 +1320,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             softmax_type=self.softmax_type,
             return_max_logit=self.return_max_logit,
             cuda_graph=is_graph_capturing(),
+            num_splits=num_splits,
         )
         global _attention_backends
         if is_in_onnx_export_mode():
@@ -1413,6 +1419,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             inference_params=inference_params,
             flash_attention_backend=flash_attention_backend,
             fp8_output=fp8_output,
+            num_splits=num_splits,
         )
     if use_fused_attention:
......
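Taken together, num_splits is a plain pass-through from DotProductAttention.forward to the FlashAttention-3 kernel; other backends ignore it and are filtered out when it is not 1. A minimal usage sketch, assuming a Hopper GPU with FA3 installed and selected as the backend (tensor shapes follow the default sbhd layout, and the head counts simply mirror the new test configs):

import torch
from transformer_engine.pytorch import DotProductAttention

# 24 heads of 128 channels, as in the num_splits test configs above.
attn = DotProductAttention(num_attention_heads=24, kv_channels=128)

# Default qkv_format is "sbhd": (seq, batch, heads, head_dim).
q = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")

# Forwarded to FA3 only; ignored when context parallelism is enabled.
out = attn(q, k, v, num_splits=4)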
@@ -135,7 +135,7 @@ class FlashAttentionUtils:
     # Please follow these instructions to install FA3
     v3_installation_steps = """\
     (1) git clone https://github.com/Dao-AILab/flash-attention.git
-    (2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
+    (2) cd flash-attention/hopper && python setup.py install
     (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
     (4) mkdir -p $python_path/flash_attn_3
     (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
@@ -233,6 +233,8 @@ class AttentionParams:
         Whether to output max_logit.
     cuda_graph: bool, default = `False`
         Whether support for cuda graph capture is needed or not.
+    num_splits: int, default = 1
+        The number of kernels to split attention to.
     """

     qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
@@ -263,6 +265,7 @@ class AttentionParams:
     softmax_type: str = "vanilla"
     return_max_logit: bool = False
     cuda_graph: bool = False
+    num_splits: int = 1

     def __eq__(self, other):
         """
@@ -338,6 +341,7 @@ def get_attention_backend(
     softmax_type = attention_params.softmax_type
     return_max_logit = attention_params.return_max_logit
     cuda_graph = attention_params.cuda_graph
+    num_splits = attention_params.num_splits

     # Run config
     logger = logging.getLogger("DotProductAttention")
@@ -511,6 +515,18 @@ def get_attention_backend(
             use_flash_attention = False
             use_fused_attention = False

+    # Filter: num_splits
+    if num_splits != 1:
+        if use_flash_attention_2 and FlashAttentionUtils.is_installed:
+            logger.debug("Disabling FlashAttention 2 for num_splits")
+            use_flash_attention_2 = False
+        if use_fused_attention:
+            logger.debug("Disabling FusedAttention for num_splits")
+            use_fused_attention = False
+        if use_unfused_attention:
+            logger.debug("Disabling UnfusedDotProductAttention for num_splits")
+            use_unfused_attention = False
+
     # Filter: Return max_logit
     if return_max_logit:
         if use_flash_attention:
@@ -1566,8 +1582,9 @@ def _pack_tensor(
     """
     Packs the given tensor using the `indices`.
     """
+    dtype = tensor.dtype if not isinstance(tensor, Float8Tensor) else torch.uint8
     padding_indice = torch.zeros(
-        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
+        1, tensor.shape[1], tensor.shape[2], dtype=dtype, device=tensor.device
     )
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
     if isinstance(tensor, Float8Tensor):
@@ -1622,8 +1639,9 @@ def _unpack_tensor(
     Inverse of `_pack_tensor`.
     """
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
+    dtype = tensor.dtype if not isinstance(tensor, Float8Tensor) else torch.uint8
     unpacked = torch.zeros(
-        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
+        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=dtype, device=tensor.device
    )
     if isinstance(tensor, Float8Tensor):
         unpacked.scatter_(0, indices, tensor._data)
......
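The _pack_tensor/_unpack_tensor change stops allocating the padding and output buffers in the tensor's nominal dtype for Float8Tensor inputs; the buffers are created as uint8 instead, so the gather/scatter operates on the raw byte payload (tensor._data). A self-contained sketch of the same idea on a plain uint8 payload follows; the helper name and shapes are illustrative, not the module's API.

import torch

def pack_rows_bytes(data_u8: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Gather rows of a (tokens, heads, dim) uint8 payload; a zero row is appended
    # so the sentinel index (== number of rows) reads padding, mirroring the
    # uint8 fallback used for Float8Tensor above.
    pad = torch.zeros(1, data_u8.shape[1], data_u8.shape[2],
                      dtype=data_u8.dtype, device=data_u8.device)
    padded = torch.cat([data_u8, pad], dim=0)
    idx = indices.view(-1, 1, 1).expand(-1, data_u8.shape[1], data_u8.shape[2])
    return torch.gather(padded, 0, idx)

raw = torch.randint(0, 255, (6, 2, 4), dtype=torch.uint8)  # stands in for Float8Tensor._data
order = torch.tensor([3, 0, 5, 6])  # index 6 selects the zero padding row
print(pack_rows_bytes(raw, order).shape)  # torch.Size([4, 2, 4])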