# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""

from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper

def _moved_to_flax(api, api_name):
    """Wrap *api* with a deprecation warning saying it now lives in transformer_engine.jax.flax."""
    return deprecate_wrapper(api, api_name + " is moving to transformer_engine.jax.flax module")


# Backward-compatible aliases: these symbols have moved to the flax submodule,
# but remain importable from the package root with a deprecation warning.
extend_logical_axis_rules = _moved_to_flax(flax.extend_logical_axis_rules, "extend_logical_axis_rules")
DenseGeneral = _moved_to_flax(flax.DenseGeneral, "DenseGeneral")
LayerNorm = _moved_to_flax(flax.LayerNorm, "LayerNorm")
LayerNormDenseGeneral = _moved_to_flax(flax.LayerNormDenseGeneral, "LayerNormDenseGeneral")
LayerNormMLP = _moved_to_flax(flax.LayerNormMLP, "LayerNormMLP")
TransformerEngineBase = _moved_to_flax(flax.TransformerEngineBase, "TransformerEngineBase")
MultiHeadAttention = _moved_to_flax(flax.MultiHeadAttention, "MultiHeadAttention")
RelativePositionBiases = _moved_to_flax(flax.RelativePositionBiases, "RelativePositionBiases")
TransformerLayer = _moved_to_flax(flax.TransformerLayer, "TransformerLayer")
TransformerLayerType = _moved_to_flax(flax.TransformerLayerType, "TransformerLayerType")

# Public API of transformer_engine.jax: FP8 helpers, sharding types, the flax
# submodule, and the deprecated backward-compatibility aliases defined above.
# NOTE(review): 'praxis' is exported here but no `from . import praxis` is
# visible in this view of the file — confirm the import exists elsewhere,
# otherwise `from transformer_engine.jax import *` will fail on it.
__all__ = [
    'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
    'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis', 'DenseGeneral',
    'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase',
    'MultiHeadAttention', 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
]