Unverified Commit 79a9fe29 authored by cyanguwa, committed by GitHub

flash-attn integration (#62)



* add flash attention to TransformerLayer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Add docs for FP8 calibration (#61)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Fix the integer overflow in fused softmax (#60)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* prefix flash attn env var with NVTE_
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Address steady memory increase and bloated checkpoints (#63)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix env var logic
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix flash attn env var logic again
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* remove d2d copies (#64)

* remove d2d copies
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Increase number of FP8 tensors per GEMM (#22)

* Increase number of FP8 tensors per GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable FP8 output tensor for fp8_gemm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [BERT FP8] Initial TE review comments
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Temporary fix for cuda graph non convergence
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Address review comments-2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Review comments-3
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change for New API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove unnecessary clone for D_scale, D_amax
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Avoid Roll for AMAX history size = 1
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Update onnx_te_gemm API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix Lint errors
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Bug fixes from PR 22 (#65)

* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* bundle unittests for ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* replace rearrange with transpose
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* QKV parameters unfused path fixes and optimization (#66)

* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better QKV parameter fusion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* small fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* keep original param for unfused case to retain externally set attrs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX exports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* improve arg naming
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* No need to set data pointers
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Assert memory loc in NoopCat
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Handle case of different memory in param and buffer
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reassign params memory to avoid more concats
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Fix gradients when using AMP (#70)

retain grad related attrs while casting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix pylint violations

fixed pylint violations such as trailing white spaces and too-long lines
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* fix pylint violation on line 264 with R1719
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* fix two more pylint violations
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>

* DotProductAttention API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add docs for attention
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix assert always true
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* check for correct flash-attn version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint+build fixes, correct settings for default flash-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* correct version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix onnx and disable flash-attn export test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove einops dependency
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup internal API; rm duplication
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* only install TE wheel (exclude flash-attn to rm conflicts)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* forgot to change install wheel path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix flash_attn output
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix QK layer scaling
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* update docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fixes to selective checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: cyanguwa <cyang.uwa@gmail.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
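
A minimal usage sketch of the new DotProductAttention module (illustrative only, not part of this PR; argument names and input shapes follow the docstrings in the diff below, and it assumes a CUDA device plus the flash-attn build pinned in setup.py, since the module imports flash_attn at import time):

    # Illustrative only. Inputs are [sequence_length, batch_size, num_attention_heads, kv_channels].
    import os
    # os.environ["NVTE_FLASH_ATTN"] = "0"  # env var added in this PR; set before construction to force the unfused path

    import torch
    import transformer_engine.pytorch as te

    seq_len, batch, heads, kv_channels = 128, 2, 16, 64

    attn = te.DotProductAttention(
        num_attention_heads=heads,
        kv_channels=kv_channels,
        attention_dropout=0.1,
        attn_mask_type="causal",
    ).cuda()

    q = torch.randn(seq_len, batch, heads, kv_channels, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # fp16/bf16 inputs with no attention_mask take the flash-attn path (unless NVTE_FLASH_ATTN=0);
    # other dtypes or an explicit mask fall back to the unfused softmax + BMM implementation.
    out = attn(q, k, v)
    print(out.shape)  # torch.Size([128, 2, 1024]) == [seq_len, batch, heads * kv_channels]
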
parent f06e2d85
@@ -25,9 +25,9 @@ jobs:
         uses: actions/upload-artifact@v3
         with:
           name: te_wheel
-          path: wheelhouse/*.whl
+          path: wheelhouse/transformer_engine*.whl
           retention-days: 7
       - name: 'Install'
-        run: pip install --no-cache-dir wheelhouse/*.whl
+        run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
       - name: 'Sanity check'
         run: python tests/test_sanity_import.py

@@ -20,6 +20,9 @@ Modules
 .. autoclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
   :members: forward
+.. autoclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
+  :members: forward
 .. autoclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
   :members: forward

@@ -7,4 +7,5 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}
 pip install pytest==6.2.5 onnxruntime==1.13.1
-pytest -v -s $TE_PATH/tests/*.py
+pytest -v -s $TE_PATH/tests/test_transformerengine.py $TE_PATH/tests/test_fp8.py
+NVTE_FLASH_ATTN=0 pytest -v -s $TE_PATH/tests/test_onnx_export.py

@@ -313,5 +313,8 @@ setup(
     description="Transformer acceleration library",
     ext_modules=ext_modules,
     cmdclass={"build_ext": TEBuildExtension},
+    install_requires = [
+        "flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",
+    ],
     license_files=("LICENSE",),
 )

@@ -793,7 +793,7 @@ def test_export_core_attention(
     if attn_mask_type is None:
         attn_mask_type = 'causal'
-    model = te.transformer.CoreAttention(
+    model = te.transformer.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
         attention_dropout=0.5,

@@ -7,6 +7,7 @@ from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
 from .module import LayerNorm
+from .transformer import DotProductAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .distributed import checkpoint

@@ -4,12 +4,15 @@
 """Transformer."""
 import os
+import re
 import math
 from contextlib import nullcontext
 from typing import Any, Callable, Optional, Tuple, Union
 import torch
+from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 from transformer_engine.pytorch import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,

@@ -37,6 +40,11 @@ from transformer_engine.pytorch.distributed import (
     checkpoint,
 )
+_flash_attn_version = re.search("Version: (.*)", os.popen("pip show flash_attn").read()).group(1)
+__all__ = ["DotProductAttention", "TransformerLayer"]
 class DropPath(torch.nn.Module):
     """Drop paths (Stochastic Depth) per sample

@@ -63,66 +71,35 @@ class DropPath(torch.nn.Module):
         return output
-class CoreAttention(torch.nn.Module):
+class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
     BMM1 -> softmax + dropout -> BMM2
     """
     def __init__(
         self,
-        num_attention_heads: int,
-        kv_channels: int,
-        attention_dropout: float,
+        norm_factor: float,
+        attention_dropout: float = 0.0,
+        attention_dropout_ctx: Optional[Callable] = nullcontext,
         layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = True,
-        attention_softmax_in_fp32: bool = False,
+        apply_query_key_layer_scaling: bool = False,
+        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
-        tp_size: int = 1,
-        get_rng_state_tracker: Optional[Callable] = None,
-        sequence_parallel: bool = False,
     ) -> None:
         super().__init__()
-        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
-        if layer_number is None:
-            self.apply_query_key_layer_scaling = False
-        else:
-            self.layer_number = max(1, layer_number)
-        if self.apply_query_key_layer_scaling:
-            self.attention_softmax_in_fp32 = True
-        self.attn_mask_type = attn_mask_type
-        projection_size = kv_channels * num_attention_heads
         assert (
             attn_mask_type in AttnMaskTypes
         ), f"attn_mask_type {attn_mask_type} not supported"
-        # Per attention head and per partition values.
-        self.hidden_size_per_partition = divide(projection_size, tp_size)
-        self.hidden_size_per_attention_head = divide(
-            projection_size, num_attention_heads
-        )
-        self.sequence_parallel = sequence_parallel
-        if self.sequence_parallel or get_rng_state_tracker is None:
-            self.attention_dropout_ctx = nullcontext
-        else:
-            self.attention_dropout_ctx = get_rng_state_tracker().fork
-        coeff = None
-        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-        if self.apply_query_key_layer_scaling:
-            coeff = self.layer_number
-            self.norm_factor *= coeff
+        self.norm_factor = norm_factor
+        self.attention_dropout_ctx = attention_dropout_ctx
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            self.attn_mask_type,
+            attn_mask_type,
             attention_mask_func,
-            self.attention_softmax_in_fp32,
-            coeff,
+            attention_softmax_in_fp32,
+            layer_number if apply_query_key_layer_scaling else None,
         )
         # Dropout. Note that for a single iteration, this layer will generate

@@ -135,9 +112,11 @@ class CoreAttention(torch.nn.Module):
         query_layer: torch.Tensor,
         key_layer: torch.Tensor,
         value_layer: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """core attention fprop"""
+        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
         # [b, np, sq, sk]
         output_size = (
             query_layer.size(1),

@@ -211,14 +190,290 @@ class CoreAttention(torch.nn.Module):
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
         # [sq, b, np, hn] --> [sq, b, hp]
-        new_context_layer_shape = context_layer.size()[:-2] + (
-            self.hidden_size_per_partition,
-        )
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(seqlen, batch_size, -1)
         return context_layer
+class FlashAttention(torch.nn.Module):
+    """Dot product attention implementation by using the flash-attn package.
+    """
+    def __init__(
+        self,
+        norm_factor: float,
+        attention_dropout: float = 0.0,
+        attention_dropout_ctx: Optional[Callable] = nullcontext,
+        layer_number: Optional[int] = None,
+        apply_query_key_layer_scaling: bool = False,
+        attention_softmax_in_fp32: bool = True,
+        attn_mask_type: str = "causal",
+    ) -> None:
+        super().__init__()
+        if "dev" not in _flash_attn_version:
+            raise ImportError(
+                'Please install correct version of flash-attn with ' \
+                'pip install git+https://github.com/ksivaman/flash-attention.git@hopper. ' \
+                'If running on Hopper, ' \
+                'please install from source with compute capability 9.0.')
+        assert (
+            attn_mask_type == "causal"
+        ), 'FlashAttention currently only supports causal attention mask.'
+        assert (
+            attention_softmax_in_fp32
+        ), 'FlashAttention currently only supports softmax compute in fp32.'
+        self.attn_causal_mask = attn_mask_type == "causal"
+        self.norm_factor = norm_factor
+        self.attention_dropout_ctx = attention_dropout_ctx
+        self.attention_dropout = attention_dropout
+        self.layer_number = layer_number
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+    def forward(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """flash-attn fprop"""
+        assert (
+            (query_layer.dtype in [torch.float16, torch.bfloat16])
+            and (key_layer.dtype in [torch.float16, torch.bfloat16])
+            and (value_layer.dtype in [torch.float16, torch.bfloat16])
+        ), 'FlashAttention currently only supports FP16 and BF16.'
+        assert (
+            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
+        ), 'FlashAttention currently only supports CUDA tensors.'
+        assert (
+            attention_mask is None
+        ), 'FlashAttention currently does not support external attention mask.'
+        query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
+            for x in (query_layer, key_layer, value_layer)]
+        batch_size, seqlen = query_layer.shape[0], query_layer.shape[1]
+        # [b, sq, np, hn]
+        query_layer, key_layer, value_layer = [
+            x.view(x.shape[0] * x.shape[1], *x.shape[2:])
+            for x in [query_layer, key_layer, value_layer]
+        ]
+        max_seqlen = seqlen
+        cu_seqlens = torch.arange(
+            0,
+            (batch_size + 1) * seqlen,
+            step=seqlen,
+            dtype=torch.int32,
+            device=query_layer.device)
+        with self.attention_dropout_ctx():
+            output = flash_attn_unpadded_func(
+                query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
+                self.attention_dropout if self.training else 0.0,
+                softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask
+            )
+        # [(b sq), np, hn] -> [sq, b, (np hn)]
+        return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()
+class DotProductAttention(torch.nn.Module):
+    """Allows the model to jointly attend to information from different
+    representation subspaces as described in the paper:
+    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
+    .. warning::
+        For the default attention mechanism, this module executes a non-deterministic version of
+        `flash-attn <https://github.com/ksivaman/flash-attention>`_ whenever possible in order to
+        achieve optimal performance. To observe deterministic behavior, set the environment
+        variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable
+        `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
+    Parameters
+    ----------
+    num_attention_heads : int
+        number of attention heads in the transformer layer.
+    kv_channels : int
+        number of key-value channels.
+    attention_dropout: float, default = 0.0
+        dropout probability for the dropout op during multi-head attention.
+    layer_number: int, default = `None`
+        layer number of the current `DotProductAttention` when multiple such modules
+        are concatenated, for instance in consecutive transformer blocks.
+    apply_query_key_layer_scaling: bool, default = `False`
+        apply query-key layer scaling during BMM1
+        by a factor of `layer_number`
+    attention_softmax_in_fp32: bool, default = `True`
+        if set to `False`, softmax is executed in
+        the dtype of activation tensors.
+    attn_mask_type: {'causal', 'padding'}, default = `causal`
+        type of attention mask passed into softmax operation.
+    Parallelism parameters
+    ----------------------
+    sequence_parallel : bool, default = `False`
+        if set to `True`, uses sequence parallelism.
+    tp_size : int, default = 1
+        tensor parallel world size.
+    tp_group : ProcessGroup, default = `None`
+        tensor parallel process group.
+    """
+    def __init__(
+        self,
+        num_attention_heads: int,
+        kv_channels: int,
+        attention_dropout: float = 0.0,
+        layer_number: Optional[int] = None,
+        apply_query_key_layer_scaling: bool = False,
+        attention_softmax_in_fp32: bool = True,
+        attn_mask_type: str = "causal",
+        sequence_parallel: bool = False,
+        tp_size: int = 1,
+        get_rng_state_tracker: Optional[Callable] = None,
+        tp_group: Optional[dist_group_type] = None,
+    ) -> None:
+        super().__init__()
+        if layer_number is None:
+            apply_query_key_layer_scaling = False
+        else:
+            layer_number = max(1, layer_number)
+        if apply_query_key_layer_scaling:
+            attention_softmax_in_fp32 = True
+        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
+        self.tp_group = tp_group
+        self.get_rng_state_tracker = get_rng_state_tracker
+        projection_size = kv_channels * num_attention_heads
+        self.hidden_size_per_partition = divide(projection_size, tp_size)
+        self.hidden_size_per_attention_head = divide(
+            projection_size, num_attention_heads
+        )
+        if sequence_parallel or get_rng_state_tracker is None:
+            attention_dropout_ctx = nullcontext
+        else:
+            attention_dropout_ctx = get_rng_state_tracker().fork
+        norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        norm_factor_flash_attn = norm_factor
+        if apply_query_key_layer_scaling:
+            norm_factor *= layer_number
+        self.use_flash_attention = (
+            int(os.getenv("NVTE_FLASH_ATTN", "1"))
+            and attention_softmax_in_fp32
+            and attn_mask_type == "causal"
+            and not apply_query_key_layer_scaling
+        )
+        attn_kwargs = {
+            "attention_dropout": attention_dropout,
+            "attention_dropout_ctx": attention_dropout_ctx,
+            "layer_number": layer_number,
+            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
+            "attention_softmax_in_fp32": attention_softmax_in_fp32,
+            "attn_mask_type": attn_mask_type,
+        }
+        if self.use_flash_attention:
+            self.flash_attention = FlashAttention(norm_factor_flash_attn, **attn_kwargs)
+        # Instantiating both types since use of flash-attn
+        # might be ruled out due to forward inputs.
+        self.unfused_attention = UnfusedDotProductAttention(norm_factor, **attn_kwargs)
+    def _checkpointed_attention_forward(
+        self,
+        attention_func: Callable,
+        *forward_args: Tuple[torch.Tensor, ...],
+    ) -> torch.Tensor:
+        """Forward method with activation checkpointing."""
+        def custom_forward(*inputs):
+            return attention_func(*inputs)
+        hidden_states = checkpoint(
+            custom_forward,
+            False,
+            self.get_rng_state_tracker,
+            self.tp_group,
+            *forward_args,
+        )
+        return hidden_states
+    def forward(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        checkpoint_core_attention: bool = False,
+    ) -> torch.Tensor:
+        """
+        Dot Product Attention Layer.
+        .. note::
+            Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
+            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
+            :attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
+            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
+            * :attr:`kv_channels`) is returned.
+        Parameters
+        ----------
+        query_layer : torch.Tensor
+            Query tensor.
+        key_layer : torch.Tensor
+            Key tensor.
+        value_layer : torch.Tensor
+            Value tensor.
+        attention_mask : Optional[torch.Tensor], default = `None`
+            Boolean tensor used to mask out softmax input when not using flash-attn.
+        checkpoint_core_attention : bool, default = `False`
+            If true, forward activations for attention are recomputed
+            during the backward pass in order to save memory that would
+            otherwise be occupied to store the forward activations until
+            backprop.
+        """
+        use_flash_attention = self.use_flash_attention
+        if (attention_mask is not None
+            or query_layer.dtype not in [torch.bfloat16, torch.float16]
+            or key_layer.dtype not in [torch.bfloat16, torch.float16]
+            or value_layer.dtype not in [torch.bfloat16, torch.float16]
+        ):
+            use_flash_attention = False
+        if use_flash_attention:
+            if checkpoint_core_attention:
+                return self._checkpointed_attention_forward(self.flash_attention,
+                                                             query_layer,
+                                                             key_layer,
+                                                             value_layer)
+            return self.flash_attention(query_layer, key_layer, value_layer)
+        if checkpoint_core_attention:
+            return self._checkpointed_attention_forward(
+                self.unfused_attention,
+                query_layer,
+                key_layer,
+                value_layer,
+                attention_mask,
+            )
+        return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask)
 class MultiHeadAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
     BMM1 -> softmax + dropout -> BMM2

@@ -234,8 +489,8 @@ class MultiHeadAttention(torch.nn.Module):
         init_method: Callable,
         output_layer_init_method: Callable,
         layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = True,
-        attention_softmax_in_fp32: bool = False,
+        apply_query_key_layer_scaling: bool = False,
+        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,

@@ -248,7 +503,7 @@ class MultiHeadAttention(torch.nn.Module):
         attention_type: str = "self",
         set_parallel_mode: bool = False,
         fuse_qkv_params: bool = False,
-        zero_centered_gamma:bool = False,
+        zero_centered_gamma: bool = False,
     ) -> None:
         super().__init__()
         self.layer_number = (layer_number,)

@@ -343,8 +598,8 @@ class MultiHeadAttention(torch.nn.Module):
             **common_gemm_kwargs,
         )
-        # Core Self attention.
-        self.core_attention = CoreAttention(
+        # Attention.
+        self.core_attention = DotProductAttention(
             num_attention_heads,
             kv_channels,
             attention_dropout,

@@ -355,6 +610,7 @@ class MultiHeadAttention(torch.nn.Module):
             get_rng_state_tracker=get_rng_state_tracker,
             attn_mask_type=attn_mask_type,
             sequence_parallel=sequence_parallel,
+            tp_group=tp_group,
         )
         # Linear

@@ -368,37 +624,6 @@ class MultiHeadAttention(torch.nn.Module):
             **common_gemm_kwargs,
         )
-    def _checkpointed_core_attention_forward(
-        self,
-        query_layer: torch.Tensor,
-        key_layer: torch.Tensor,
-        value_layer: torch.Tensor,
-        attention_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """Forward method with activation checkpointing."""
-        def custom_forward(*inputs):
-            query_layer = inputs[0]
-            key_layer = inputs[1]
-            value_layer = inputs[2]
-            attention_mask = inputs[3]
-            output_ = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask
-            )
-            return output_
-        hidden_states = checkpoint(
-            custom_forward,
-            False,
-            self.get_rng_state_tracker,
-            self.tp_group,
-            query_layer,
-            key_layer,
-            value_layer,
-            attention_mask,
-        )
-        return hidden_states
     def _allocate_memory(
         self, inference_max_sequence_len: int, batch_size: int

@@ -419,10 +644,10 @@ class MultiHeadAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
         encoder_output: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
-        checkpoint_core_attention: Optional[bool] = None,
+        checkpoint_core_attention: bool = False,
         inference_params: Optional[Any] = None,
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         """MultiHeadAttention FWD"""

@@ -556,13 +781,12 @@ class MultiHeadAttention(torch.nn.Module):
         # core attention computation
         # ==================================
-        if checkpoint_core_attention:
-            context_layer = self._checkpointed_core_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask
-            )
-        else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask
-            )
+        context_layer = self.core_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            checkpoint_core_attention=checkpoint_core_attention,
+        )
         # =================

@@ -613,16 +837,16 @@ class TransformerLayer(torch.nn.Module):
     layer_number: int, default = `None`
        layer number of the current `TransformerLayer` when multiple such modules are
        concatenated to form a transformer block.
-    apply_query_key_layer_scaling: bool, default = `True`
+    apply_query_key_layer_scaling: bool, default = `False`
        apply query-key layer scaling during BMM1
        by a factor of `layer_number`
     output_layernorm: bool, default = `False`
        if set to `True`, layer normalization is applied on the output side,
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
-    attention_softmax_in_fp32: bool, default = `False`
-        if set to `True`, softmax is executed in
-        torch.float32 dtype (single precision)
+    attention_softmax_in_fp32: bool, default = `True`
+        if set to `False`, softmax is executed in
+        the dtype of activation tensors.
     layer_type: {'encoder', 'decoder'}, default = `encoder`
        if set to `decoder`, an additional cross-attn block is added after self-attn.
        This can be used for structures like `T5` Transformer in conjunction with the

@@ -702,8 +926,8 @@ class TransformerLayer(torch.nn.Module):
         params_dtype: torch.dtype = torch.float32,
         get_rng_state_tracker: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
-        apply_query_key_layer_scaling: bool = True,
-        attention_softmax_in_fp32: bool = False,
+        apply_query_key_layer_scaling: bool = False,
+        attention_softmax_in_fp32: bool = True,
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         sequence_parallel: bool = False,

@@ -770,7 +994,7 @@ class TransformerLayer(torch.nn.Module):
             "return_layernorm_output": apply_residual_connection_post_layernorm,
             "set_parallel_mode": set_parallel_mode,
             "fuse_qkv_params": fuse_qkv_params,
-            "zero_centered_gamma": zero_centered_gamma
+            "zero_centered_gamma": zero_centered_gamma,
         }
         self.self_attention = MultiHeadAttention(

@@ -856,11 +1080,11 @@ class TransformerLayer(torch.nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
         encoder_output: Optional[torch.Tensor] = None,
         enc_dec_attn_mask: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
-        checkpoint_core_attention: Optional[bool] = False,
+        checkpoint_core_attention: bool = False,
         inference_params: Optional[Any] = None,
     ) -> torch.Tensor:
         """

@@ -870,12 +1094,12 @@ class TransformerLayer(torch.nn.Module):
         ----------
         hidden_states : torch.Tensor
             Input tensor.
-        attention_mask : torch.Tensor
+        attention_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out self-attention softmax input.
-        encoder_output : torch.Tensor
+        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
-        enc_dec_attn_mask : torch.Tensor
+        enc_dec_attn_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out inter-attention softmax input if using
             `layer_type="decoder"`.
         is_first_microbatch : {True, False, None}, default = None

@@ -891,7 +1115,7 @@ class TransformerLayer(torch.nn.Module):
             * it also allows skipping gradient accumulation during the
               first microbatch (since it is the first gradient being
               produced)
-        checkpoint_core_attention: bool, default = `True`
+        checkpoint_core_attention: bool, default = `False`
             If true, forward activations for core attention are recomputed
             during the backward pass in order to save memory that would
             otherwise be occupied to store the forward activations until
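
For reference, the fixed-length to varlen repacking that the FlashAttention wrapper above performs before calling flash_attn_unpadded_func can be sketched standalone; this uses made-up sizes, runs on CPU, and does not call flash-attn itself:

    import torch

    # Illustrative sizes only.
    sq, b, np_, hn = 4, 2, 3, 8          # seq len, batch, heads, head dim
    q = torch.randn(sq, b, np_, hn)      # module input layout: [sq, b, np, hn]

    # [sq, b, np, hn] -> [b, sq, np, hn] -> [(b * sq), np, hn], as in FlashAttention.forward
    q_packed = q.transpose(0, 1).contiguous().view(b * sq, np_, hn)

    # Cumulative sequence lengths marking where each sequence starts/ends in the packed dim.
    cu_seqlens = torch.arange(0, (b + 1) * sq, step=sq, dtype=torch.int32)

    print(q_packed.shape)  # torch.Size([8, 3, 8])
    print(cu_seqlens)      # tensor([0, 4, 8], dtype=torch.int32)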