# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX.

This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.

The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training

All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order

import importlib
import importlib.util
import logging
import sys
from importlib.metadata import version

from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension

def _load_library():
    """Load the shared library with Transformer Engine C extensions for JAX.

    Locates the compiled ``transformer_engine_jax`` extension (searching the
    installed ``transformer_engine`` package directory, its ``wheel_lib``
    sub-directory, and finally the package root), imports it from the shared
    object, and registers it in ``sys.modules`` under its module name.

    Raises:
        AssertionError: if the ``transformer_engine`` / ``transformer-engine-cu12``
            metapackages are missing or their versions do not all match.
        StopIteration: if no matching shared object is found in any search location.
    """
    module_name = "transformer_engine_jax"

    if is_package_installed(module_name):
        # The framework extension is distributed alongside the core packages;
        # all three must be present and at the same version.
        assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
        assert is_package_installed(
            "transformer_engine_cu12"
        ), "Could not find `transformer-engine-cu12`."
        assert (
            version(module_name)
            == version("transformer-engine")
            == version("transformer-engine-cu12")
        ), (
            "TransformerEngine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and transformer-engine-cu12"
            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
            "'pip3 install transformer-engine[jax]==VERSION'"
        )

    # Best-effort hint: the CUDA package is installed but the JAX extension
    # package is not. Loading still proceeds via the filesystem search below.
    if is_package_installed("transformer-engine-cu12") and not is_package_installed(module_name):
        logging.info(
            "Could not find package %s. Install transformer-engine using "
            "'pip3 install transformer-engine[jax]==VERSION'",
            module_name,
        )

    extension = _get_sys_extension()
    try:
        # Preferred location: inside the installed `transformer_engine` package.
        so_dir = get_te_path() / "transformer_engine"
        so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
    except StopIteration:
        try:
            # Wheel builds may place the shared object under `wheel_lib`.
            so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
            so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
        except StopIteration:
            # Last resort: the package root (e.g. an editable/source install).
            so_dir = get_te_path()
            so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))

    # Import the shared object as a regular module and make it importable
    # by name for the rest of the process.
    spec = importlib.util.spec_from_file_location(module_name, so_path)
    solib = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = solib
    spec.loader.exec_module(solib)


_load_library()

# These imports must come after the C-extension library is loaded, hence the
# `wrong-import-position` pylint disable at the top of the file.
from . import flax
from . import quantize
from .quantize import fp8_autocast
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType

from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

# Wrap the legacy sharding API so any use emits a deprecation warning while
# remaining functional.
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecating in the near future."
)
ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near future.")
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

__all__ = [
    "fp8_autocast",
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    "flax",
    # NOTE(review): "praxis" is exported here but no `praxis` import is visible
    # in this file — `from transformer_engine.jax import *` would raise
    # AttributeError unless it is provided elsewhere. Confirm and drop if stale.
    "praxis",
]