# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Transformer Engine bindings for Paddle"""

# pylint: disable=wrong-import-position,wrong-import-order
import logging
from importlib.metadata import version

from transformer_engine.common import is_package_installed

def _load_library():
    """Load shared library with Transformer Engine C extensions"""
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    module_name = "transformer_engine_paddle"

    if is_package_installed(module_name):
        assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
        assert is_package_installed(
            "transformer_engine_cu12"
        ), "Could not find `transformer-engine-cu12`."
        assert (
            version(module_name)
            == version("transformer-engine")
            == version("transformer-engine-cu12")
        ), (
            "TransformerEngine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and transformer-engine-cu12"
            f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install"
            " transformer-engine[paddle]==VERSION'"
        )

    if is_package_installed("transformer-engine-cu12"):
        if not is_package_installed(module_name):
            logging.info(
                "Could not find package %s. Install transformer-engine using 'pip"
                " install transformer-engine[paddle]==VERSION'",
                module_name,
            )

44
45
46
47
    from transformer_engine import transformer_engine_paddle  # pylint: disable=unused-import


_load_library()
from .fp8 import fp8_autocast
from .layer import (
    Linear,
    LayerNorm,
    LayerNormLinear,
    LayerNormMLP,
    FusedScaleMaskSoftmax,
    DotProductAttention,
    MultiHeadAttention,
    TransformerLayer,
    RotaryPositionEmbedding,
)
from .recompute import recompute