Unverified Commit 79a9fe29 authored by cyanguwa, committed by GitHub

flash-attn integration (#62)



* add flash attention to TransformerLayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Add docs for FP8 calibration (#61)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Fix the integer overflow in fused softmax (#60)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
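
For context on the class of bug fixed in #60, here is a hedged illustration, not the actual kernel code: a fused softmax kernel indexes a flattened batch × heads × seq × seq buffer, and for long sequences that element count no longer fits in a signed 32-bit integer, so offsets must be computed in 64 bits.

```python
# Hypothetical illustration of the overflow mode fixed in #60 (not the real kernel).
import numpy as np

batch, heads, seq = 32, 16, 4096
total_elements = batch * heads * seq * seq   # 8,589,934,592 elements

int32_max = np.iinfo(np.int32).max           # 2,147,483,647
print(total_elements > int32_max)            # True: a 32-bit flat index would wrap around

# Promoting to 64 bits before multiplying keeps the offset arithmetic safe:
offset = np.int64(batch - 1) * heads * seq * seq
```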

* prefix flash attn env var with NVTE_
Signed-off-by: Charlene Yang <charleney@nvidia.com>
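
The flash-attention path is toggled with an environment variable, now namespaced under NVTE_. A minimal sketch of the gating logic, assuming flash attention defaults to on (the helper name is illustrative; only NVTE_FLASH_ATTN itself appears in this PR, e.g. in the ONNX test command below):

```python
import os

def use_flash_attention() -> bool:
    """Illustrative helper: flash attention is enabled unless NVTE_FLASH_ATTN=0.

    The CI change below runs the ONNX export tests with NVTE_FLASH_ATTN=0,
    which suggests the default is on.
    """
    return bool(int(os.getenv("NVTE_FLASH_ATTN", "1")))
```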

* Address steady memory increase and bloated checkpoints (#63)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix env var logic
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix flash attn env var logic again
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* remove d2d copies (#64)

* remove d2d copies
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Increase number of FP8 tensors per GEMM (#22)

* Increase number of FP8 tensors per GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable FP8 output tensor for fp8_gemm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [BERT FP8] Initial TE review comments
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Temporary fix for CUDA graph non-convergence
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Address review comments-2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Review comments-3
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change for new API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove unnecessary clone for D_scale, D_amax
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Avoid roll for amax history size = 1
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
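
The amax history is a rolling buffer that is shifted every step; when its length is 1 the shift is a no-op, so torch.roll can be skipped entirely. A hedged sketch of the idea (tensor and function names are illustrative, not the actual TE code):

```python
import torch

def update_amax_history(amax_history: torch.Tensor, new_amax: torch.Tensor) -> torch.Tensor:
    """Illustrative: shift the amax history one slot and record the newest value.

    For a history of length 1, torch.roll would rotate a single entry back onto
    itself, so the (non-free) kernel launch is skipped -- the optimization above.
    """
    if amax_history.size(0) > 1:
        amax_history = torch.roll(amax_history, shifts=-1, dims=0)
    amax_history[-1].copy_(new_amax)
    return amax_history
```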

* Update onnx_te_gemm API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix lint errors
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Bug fixes from PR 22 (#65)

* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* bundle unit tests for CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* replace rearrange with transpose
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
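
This swaps einops.rearrange calls for native tensor ops (a later commit in this PR drops the einops dependency altogether). A rough illustration of the kind of equivalence involved, assuming a [s, b, h, d] attention layout (the shapes are made up for the example):

```python
import torch

s, b, h, d = 128, 2, 16, 64
x = torch.randn(s, b, h, d)

# einops.rearrange(x, "s b h d -> b (h d) s") expressed with native ops:
y = x.permute(1, 2, 3, 0).reshape(b, h * d, s)
print(y.shape)  # torch.Size([2, 1024, 128])
```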

* QKV parameters unfused path fixes and optimization (#66)

* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better QKV parameter fusion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* small fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* keep original param for unfused case to retain externally set attrs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX exports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* improve arg naming
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* No need to set data pointers
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Assert memory location in NoopCat (see the sketch after this section)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Handle case of different memory in param and buffer
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert that was always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reassign params memory to avoid more concats
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
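
The fused-QKV path above works by making the separate q/k/v parameters views into one contiguous buffer, so concatenating them is a no-op rather than a device-to-device copy; the new asserts guard that invariant. A hedged sketch of such a check (names are illustrative, not the actual NoopCat implementation):

```python
import torch

def params_tile_buffer(params: list[torch.Tensor], buffer: torch.Tensor) -> bool:
    """Illustrative invariant check: do `params` tile `buffer` back-to-back?

    If they do, torch.cat(params) can simply return `buffer` (a "no-op concat"),
    avoiding an extra device-to-device copy on every forward pass.
    """
    offset = buffer.data_ptr()
    for p in params:
        if p.data_ptr() != offset:
            return False
        offset += p.numel() * p.element_size()
    return True
```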

* Fix gradients when using AMP (#70)

retain grad-related attrs while casting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
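
Background for the AMP fix: casting a parameter creates a new tensor, and Python attributes hung off the original (grad-related flags and the like) do not survive the cast, so they must be copied over explicitly. A minimal illustration (the attribute name is invented for the example):

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))
w.needs_extra_grad_sync = True     # invented grad-related attribute on the param

w_cast = w.half()                  # AMP-style cast returns a brand-new tensor...
print(hasattr(w_cast, "needs_extra_grad_sync"))   # False: the attribute is gone

# ...so the fix re-attaches such attributes after casting:
w_cast.needs_extra_grad_sync = w.needs_extra_grad_sync
```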

* fix pylint violations

Fixed pylint violations such as trailing whitespace and over-long lines.
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* fix pylint violation on line 264 with R1719
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* fix two more pylint violations
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* DotProductAttention API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add docs for attention
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert that was always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* check for correct flash-attn version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
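
A hedged sketch of what such a version guard can look like (the version bound here is a placeholder, not the one this commit actually pins):

```python
from importlib.metadata import version
from packaging.version import Version

_flash_attn_version = Version(version("flash-attn"))
# Placeholder lower bound for illustration only.
assert _flash_attn_version >= Version("0.2.0"), (
    f"flash-attn {_flash_attn_version} is not supported; please upgrade."
)
```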

* address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint and build fixes; correct settings for default flash-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* correct version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix ONNX export and disable the flash-attn export test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove einops dependency
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* clean up internal API; remove duplication
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* only install the TE wheel (exclude flash-attn to avoid conflicts)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix install wheel path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round of review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix flash_attn output
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix QK layer scaling
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* update docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fixes to selective checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f06e2d85
@@ -25,9 +25,9 @@ jobs:
       uses: actions/upload-artifact@v3
       with:
         name: te_wheel
-        path: wheelhouse/*.whl
+        path: wheelhouse/transformer_engine*.whl
         retention-days: 7
     - name: 'Install'
-      run: pip install --no-cache-dir wheelhouse/*.whl
+      run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
     - name: 'Sanity check'
       run: python tests/test_sanity_import.py
@@ -20,6 +20,9 @@ Modules
 .. autoclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
   :members: forward

+.. autoclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
+  :members: forward
+
 .. autoclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
   :members: forward
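
With DotProductAttention now documented and exported (see the __init__.py hunk below), usage looks roughly like this; the constructor arguments come from the autoclass entry above, while the input layout and forward signature are assumptions for illustration:

```python
import torch
import transformer_engine.pytorch as te

# Constructor arguments per the docs entry above; other kwargs left at defaults.
attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# Assumed [seq, batch, heads, head_dim] inputs; check the rendered docs
# for the exact forward() contract.
q = k = v = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.float16)
out = attn(q, k, v)
```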
@@ -7,4 +7,5 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}

 pip install pytest==6.2.5 onnxruntime==1.13.1
-pytest -v -s $TE_PATH/tests/*.py
+pytest -v -s $TE_PATH/tests/test_transformerengine.py $TE_PATH/tests/test_fp8.py
+NVTE_FLASH_ATTN=0 pytest -v -s $TE_PATH/tests/test_onnx_export.py
@@ -313,5 +313,8 @@ setup(
     description="Transformer acceleration library",
     ext_modules=ext_modules,
     cmdclass={"build_ext": TEBuildExtension},
+    install_requires = [
+        "flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",
+    ],
     license_files=("LICENSE",),
 )
@@ -793,7 +793,7 @@ def test_export_core_attention(
     if attn_mask_type is None:
         attn_mask_type = 'causal'

-    model = te.transformer.CoreAttention(
+    model = te.transformer.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
         attention_dropout=0.5,
@@ -7,6 +7,7 @@ from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
 from .module import LayerNorm
+from .transformer import DotProductAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .distributed import checkpoint