Unverified Commit 8ace813c authored by Kshitij Lakhani, committed by GitHub

Refactor attention.py part 2 (#1704)



* Move MultiHeadAttention into its own file. Modify tests and files in transformer_engine/pytorch to import from the new MHA module
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
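For illustration, a minimal sketch of the import change this implies (both paths appear in the test diffs further down; the constructor arguments are the usual hidden_size / num_attention_heads pair and are shown only as an example):

    # Before: MultiheadAttention was re-exported from the monolithic attention module
    # from transformer_engine.pytorch.attention import MultiheadAttention
    # After: it lives in its own module (the attention package __init__ shown below still re-exports it)
    from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention

    mha = MultiheadAttention(hidden_size=1024, num_attention_heads=16)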

* Restore MHA changes from PR 1614 that were lost during a rebase
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Move the context-parallelism code into its own file. Modify tests and local imports of the CP code accordingly
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
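A sketch of the corresponding change for the context-parallelism helpers, mirroring the distributed-test diff further down; get_cu_seqlens_on_cp_rank is the symbol shown there, and other CP helpers are assumed to follow the same path:

    # Before: re-exported from the monolithic attention module
    # from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
    # After: imported from the dedicated context_parallel module
    from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
        get_cu_seqlens_on_cp_rank,
    )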

* Move softmax.py from pytorch/ to pytorch/dot_product_attention
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Move unfused and fused attention to backends.py and some utility functions to pytorch/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Restore mark_activation_offload changes from PR 1678 that were lost during a rebase
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Refactor the attention directory
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Refactor the directory structure. Make the relevant symbols public in __init__ for the attention and dot_product_attention directories
Move the flash-attention package imports to backends.py
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
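As an illustration of the public surface after this step (matching the new __init__.py files shown in the diffs below), package-level imports now look like this:

    # attention/__init__.py re-exports the main public classes
    from transformer_engine.pytorch.attention import (
        DotProductAttention,
        MultiheadAttention,
        InferenceParams,
        RotaryPositionEmbedding,
    )
    # dot_product_attention/__init__.py exposes DotProductAttention and _attention_backends
    from transformer_engine.pytorch.attention.dot_product_attention import (
        DotProductAttention,
        _attention_backends,
    )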

* Modify tests to import attention modules correctly
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Lint fixes
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Code clean up and fix typo
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Allow InferenceParams and RoPE imports from the attention module and the pytorch module
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Allow InferenceParams and RoPE imports via the transformer_engine.pytorch and transformer_engine.pytorch.attention modules
Remove unnecessary check_set_window_size checks in MultiheadAttention and TransformerLayer
Reorder the backends so that the smaller classes come first and the larger ones last
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
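A small sketch of the two equivalent import paths this enables; the re-export from the top-level transformer_engine.pytorch package is taken from this commit's description rather than from the diffs shown here:

    # Via the attention subpackage (shown in the test diffs below)
    from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
    # Via the top-level pytorch package (per the commit message above)
    from transformer_engine.pytorch import InferenceParams, RotaryPositionEmbedding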

* Reinstate the changes to rope.py from PR 1478 that were lost during rebase conflict resolution
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix lint issues
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* nit: Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make imports leaner
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6c942ffd
......@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
......
......@@ -8,11 +8,9 @@ import gc
from contextlib import contextmanager
import torch
from torch import nn
import transformer_engine as te
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformers
from transformers.models.llama.modeling_llama import (
......
......@@ -2,12 +2,16 @@
#
# See LICENSE for license information.
import os, sys, logging
import os
import sys
import logging
from contextlib import nullcontext
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import logging
import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
......@@ -15,26 +12,22 @@ import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import (
from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention,
MultiheadAttention,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.utils import (
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
QKVLayout,
fused_attn_bwd,
fused_attn_fwd,
)
......@@ -49,9 +42,7 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_Fused_Attn_Backend
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
......
......@@ -11,7 +11,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
model_configs_flash_attn = {
......
......@@ -11,6 +11,12 @@ import math
import pytest
import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
......@@ -18,20 +24,15 @@ from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
TransformerLayer,
)
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
# Initialize RNG state
seed = 1234
......
......@@ -12,7 +12,7 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union
import math
import pytest
import torch
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.dot_product_attention.rope import (
import pytest
from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
......
......@@ -7,7 +7,6 @@ import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
import torch
......@@ -38,7 +37,7 @@ from transformer_engine.pytorch import (
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
......
......@@ -90,7 +90,8 @@ from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for attention"""
from .dot_product_attention import DotProductAttention
from .multi_head_attention import MultiheadAttention
from .inference import InferenceParams
from .rope import RotaryPositionEmbedding
__all__ = [
"DotProductAttention",
"MultiheadAttention",
"InferenceParams",
"RotaryPositionEmbedding",
]
......@@ -3,3 +3,7 @@
# See LICENSE for license information.
"""Python interface for dot product attention"""
from .dot_product_attention import DotProductAttention, _attention_backends
__all__ = ["DotProductAttention", "_attention_backends"]
......@@ -34,7 +34,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
......@@ -53,6 +53,8 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_cu_seqlens_cache = {}
class AttentionLogging:
"""
......@@ -63,6 +65,7 @@ class AttentionLogging:
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
fa_logger = logging.getLogger(__name__)
_is_logging_setup = False
@staticmethod
def setup_logging():
......@@ -77,6 +80,7 @@ class AttentionLogging:
AttentionLogging.fa_logger.setLevel(AttentionLogging._log_level)
if not AttentionLogging.fa_logger.hasHandlers():
AttentionLogging.fa_logger.addHandler(AttentionLogging._stream_handler)
AttentionLogging._is_logging_setup = True
@functools.lru_cache(maxsize=None)
......@@ -87,6 +91,11 @@ def _get_supported_versions(version_min, version_max):
return ">= " + str(version_min) + ", " + "<= " + str(version_max)
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
"""Make tensor contiguous if final stride is not 1."""
return tensor.contiguous() if tensor.stride(-1) != 1 else tensor
class FlashAttentionUtils:
"""
Manage Flash Attention versioning information
......@@ -1295,9 +1304,6 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
return indices
_cu_seqlens_cache = {}
def get_full_cu_seqlens(
batch_size: int,
max_seqlen: int,
......
......@@ -246,7 +246,6 @@ def _apply_rotary_pos_emb_base(
# [seq, b, 1, dim] -> [b, seq, 1, dim]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1)
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
......