# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Transformer Engine bindings for Paddle"""

# pylint: disable=wrong-import-position,wrong-import-order


def _load_library():
    """Import the Transformer Engine C extension module for Paddle.

    The import is done only for its side effect of loading the compiled
    shared library. Nothing from the module is used here, so the local
    reference is discarded immediately.
    """
    from transformer_engine import transformer_engine_paddle as _extension

    del _extension


# Load the compiled extension before the package submodules below are imported.
_load_library()
from .fp8 import fp8_autocast
from .layer import (
    Linear,
    LayerNorm,
    LayerNormLinear,
    LayerNormMLP,
    FusedScaleMaskSoftmax,
    DotProductAttention,
    MultiHeadAttention,
    TransformerLayer,
    RotaryPositionEmbedding,
)
from .recompute import recompute