# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX.

This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.

The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training

All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""

# The package imports below intentionally follow the `load_framework_extension`
# call, so import-position linting is disabled for this file.
# pylint: disable=wrong-import-position

# This unused import is needed because the top level `transformer_engine/__init__.py`
# file catches an `ImportError` as a guard for cases where the given framework's
# extensions are not available.
import jax

from transformer_engine.common import load_framework_extension

# Load the JAX framework extension before importing the submodules below.
# NOTE(review): presumably `flax`/`quantize`/`sharding` rely on the extension
# already being loaded at import time — confirm before reordering.
load_framework_extension("jax")

from . import flax
from . import quantize

from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import NVTE_FP8_COLLECTION_NAME

from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType

from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum
# Backward-compatible aliases: the original names stay importable from this
# package but are rebound to wrappers carrying a deprecation message.
# The wrapping behaviour itself lives in `DeprecatedEnum` / `deprecate_wrapper`
# (imported from `..common.utils`).
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecating in the near future."
)
ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near future.")
# `ShardingResource` has been superseded by `MeshResource` (imported above).
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

# Public API re-exported by `from transformer_engine.jax import *`.
__all__ = [
    # Quantization / FP8 helpers (from `.quantize`).
    "NVTE_FP8_COLLECTION_NAME",
    "fp8_autocast",
    "update_collections",
    "get_delayed_scaling",
    # Sharding resources (from `.sharding`); the last three are deprecated aliases.
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    # Submodules.
    "flax",
    "quantize",
]