Unverified Commit 79a9fe29 authored by cyanguwa, committed by GitHub

flash-attn integration (#62)



* add flash attention to TransformerLayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Add docs for FP8 calibration (#61)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Fix the integer overflow in fused softmax (#60)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
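For context, the overflow fixed in #60 is the classic case of a flat element offset computed in 32-bit arithmetic. A plain-Python sketch that emulates int32 wraparound (the real kernel is CUDA C++, and these names are illustrative, not TE's):

```python
# Illustrative only: shows how a fused-softmax element offset can overflow
# once batch * heads * seqlen^2 exceeds 2**31 - 1. The fix in the kernel is
# to do this index arithmetic in 64-bit integers.
def to_int32(x):
    """Emulate C int32 wraparound."""
    return ((x + 2**31) % 2**32) - 2**31

def softmax_offset(batch, head, row, heads, seqlen, use_int64=False):
    # Flat offset of one softmax row in a [batch, heads, seqlen, seqlen] buffer.
    idx = ((batch * heads + head) * seqlen + row) * seqlen
    return idx if use_int64 else to_int32(idx)

# 32 batches x 16 heads x 4096 sequence length is enough to overflow int32.
bad = softmax_offset(31, 15, 4095, heads=16, seqlen=4096)                  # -4096
good = softmax_offset(31, 15, 4095, heads=16, seqlen=4096, use_int64=True)  # 8589930496
```

The 32-bit offset wraps negative, so the kernel would read far outside the tensor.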

* prefix flash attn env var with NVTE_
Signed-off-by: Charlene Yang <charleney@nvidia.com>
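The renamed variable gates whether the flash-attn backend is used. A minimal sketch of env-var gating in this style (only the `NVTE_FLASH_ATTN` name comes from this commit; the parsing inside TransformerLayer may differ):

```python
import os

def use_flash_attention(default=True):
    # NVTE_FLASH_ATTN=0 disables the flash-attn backend; unset keeps the
    # default. This parsing is a sketch, not TE's exact logic.
    val = os.getenv("NVTE_FLASH_ATTN")
    if val is None:
        return default
    return bool(int(val))

os.environ["NVTE_FLASH_ATTN"] = "0"
assert use_flash_attention() is False
```

The CI change later in this diff (`NVTE_FLASH_ATTN=0 pytest ...`) uses exactly this switch to force the non-flash path during ONNX export tests.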

* Address steady memory increase and bloated checkpoints (#63)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix env var logic
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix flash attn env var logic again
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* remove d2d copies (#64)

* remove d2d copies
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Increase number of FP8 tensors per GEMM (#22)

* Increase number of FP8 tensors per GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable FP8 output tensor for fp8_gemm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [BERT FP8] Initial TE review comments
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Temporary fix for cuda graph non convergence
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Address review comments-2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Review comments-3
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change for New API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove unnecessary clone for D_scale, D_amax
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Avoid Roll for AMAX history size = 1
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
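The amax history is a rolling window of per-tensor absolute maxima used for FP8 scaling; when the window length is 1, rolling it is a pure no-op that still costs a kernel launch. A plain-Python sketch of the skip (TE operates on GPU tensors via `torch.roll`; these names are illustrative):

```python
def update_amax_history(history, new_amax):
    # history[0] holds the most recent amax. Rolling a length-1 history
    # changes nothing, so skip the shift entirely in that case -- the
    # optimization this commit applies to the device-side roll.
    if len(history) > 1:
        history = [new_amax] + history[:-1]  # shift window, drop oldest
    else:
        history = [new_amax]                 # no roll needed
    return history

h = update_amax_history([1.0, 2.0, 3.0], 4.0)  # [4.0, 1.0, 2.0]
```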

* Update onnx_te_gemm API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix Lint errors
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Bug fixes from PR 22 (#65)

* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* bundle unittests for ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* replace rearrange with transpose
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* QKV parameters unfused path fixes and optimization (#66)

* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better QKV parameter fusion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* small fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* keep original param for unfused case to retain externally set attrs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX exports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* improve arg naming
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* No need to set data pointers
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Assert memory loc in NoopCat
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Handle case of different memory in param and buffer
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reassign params memory to avoid more concats
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Fix gradients when using AMP (#70)

retain grad related attrs while casting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
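The AMP fix comes down to carrying autograd-related attributes across a dtype cast: replacing a parameter with a cast copy that silently drops `requires_grad` cuts the tensor out of the backward pass. A torch-free sketch of the idea (class and function names are hypothetical):

```python
class FakeParam:
    """Stand-in for a tensor parameter (torch-free illustration)."""
    def __init__(self, data, requires_grad=False):
        self.data = data
        self.requires_grad = requires_grad

def cast(param, to_type):
    out = FakeParam([to_type(x) for x in param.data])
    # The fix: retain grad-related attributes on the cast copy; otherwise
    # autograd would treat the new tensor as a constant.
    out.requires_grad = param.requires_grad
    return out

p = FakeParam([1.0, 2.0], requires_grad=True)
q = cast(p, float)
assert q.requires_grad is True  # attribute survives the cast
```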

* fix pylint violations

fixed pylint violations such as trailing whitespace and overlong lines
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* fix pylint violation on line 264 with R1719
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* fix two more pylint violations
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* DotProductAttention API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add docs for attention
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* check for correct flash-attn version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint+build fixes, correct settings for default flash-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* correct version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix onnx and disable flash-attn export test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove einops dependency
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup internal API; rm duplication
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* only install TE wheel (exclude flash-attn to rm conflicts)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* forgot to change install wheel path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix flash_attn output
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix QK layer scaling
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* update docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fixes to selective checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f06e2d85
@@ -25,9 +25,9 @@ jobs:
       uses: actions/upload-artifact@v3
       with:
         name: te_wheel
-        path: wheelhouse/*.whl
+        path: wheelhouse/transformer_engine*.whl
         retention-days: 7
     - name: 'Install'
-      run: pip install --no-cache-dir wheelhouse/*.whl
+      run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
     - name: 'Sanity check'
       run: python tests/test_sanity_import.py
@@ -20,6 +20,9 @@ Modules
 .. autoclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
   :members: forward
+.. autoclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
+  :members: forward
 .. autoclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
   :members: forward
@@ -7,4 +7,5 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}
 pip install pytest==6.2.5 onnxruntime==1.13.1
-pytest -v -s $TE_PATH/tests/*.py
+pytest -v -s $TE_PATH/tests/test_transformerengine.py $TE_PATH/tests/test_fp8.py
+NVTE_FLASH_ATTN=0 pytest -v -s $TE_PATH/tests/test_onnx_export.py
@@ -313,5 +313,8 @@ setup(
     description="Transformer acceleration library",
     ext_modules=ext_modules,
     cmdclass={"build_ext": TEBuildExtension},
+    install_requires = [
+        "flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",
+    ],
     license_files=("LICENSE",),
 )
@@ -793,7 +793,7 @@ def test_export_core_attention(
     if attn_mask_type is None:
         attn_mask_type = 'causal'
-    model = te.transformer.CoreAttention(
+    model = te.transformer.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
         attention_dropout=0.5,
@@ -7,6 +7,7 @@ from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
 from .module import LayerNorm
+from .transformer import DotProductAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .distributed import checkpoint