Unverified Commit 0792ded4 authored by Ming-Xu Huang, committed by GitHub

[JAX] Adjust Module Structure. (#169)



* Adjust Module Structure.

1. Collect Flax-related modules into a sub-folder, flax.
2. Add a function to unify scale_init for zero-centered-gamma LayerNorm.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make changes compatible with previous versions.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adapt jax/examples to the new module structure.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update jax/docs and add deprecation warnings.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update README.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add deprecate_wrapper.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add a deprecation warning to Flax modules imported via transformer_engine.jax.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix CI errors and update docs.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove an unnecessary deprecation warning from the docs.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Implement __iter__ for DeprecatedEnum.

Signed-off-by: Ming Huang <mingh@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
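
For orientation before the diff: the user-facing effect of the restructure is just a change of import path; constructor arguments are untouched. A minimal sketch (the feature size, input shape, and PRNG seed below are illustrative, not from this commit):

import jax
import jax.numpy as jnp
import transformer_engine.jax as te

# Before this change: model = te.DenseGeneral(features=256)
# After this change, Flax modules live in the te.flax sub-module:
model = te.flax.DenseGeneral(features=256)

x = jnp.ones((32, 128), dtype=jnp.float32)
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)
print(y.shape)    # (32, 256)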
parent 522fecc1
......@@ -69,6 +69,9 @@ pyTorch
JAX
^^^
Flax
~~~~
.. code-block:: python
import jax
......@@ -90,7 +93,7 @@ JAX
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
model = te.DenseGeneral(features=HIDDEN)
model = te.flax.DenseGeneral(features=HIDDEN)
def loss_fn(params, other_vars, inp):
out = model.apply({'params':params, **other_vars}, inp)
......
......@@ -9,34 +9,33 @@ Jax
.. autoapiclass:: transformer_engine.jax.MajorShardingType
.. autoapiclass:: transformer_engine.jax.ShardingType
.. autoapiclass:: transformer_engine.jax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.update_collections
.. autoapifunction:: transformer_engine.jax.update_fp8_metas
.. autoapiclass:: transformer_engine.jax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.MultiHeadAttention(head_dim, num_heads, **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
:members: __call__
.. autoapifunction:: transformer_engine.jax.extend_logical_axis_rules
.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.update_collections
.. autoapifunction:: transformer_engine.jax.update_fp8_metas
\ No newline at end of file
.. autoapifunction:: transformer_engine.jax.flax.extend_logical_axis_rules
......@@ -59,7 +59,7 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.TransformerLayer,
te_Encoder = partial(te.flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
......@@ -73,17 +73,17 @@ class Net(nn.Module):
x = x.reshape(x.shape[0], -1)
x = te.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x)
x = te.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW,
dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW,
dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
......
......@@ -56,7 +56,7 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.TransformerLayer,
te_Encoder = partial(te.flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
......@@ -70,9 +70,11 @@ class Net(nn.Module):
x = x.reshape(x.shape[0], -1)
x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
dtype=jnp.bfloat16)(x)
x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
......
......@@ -46,7 +46,7 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.TransformerLayer,
te_Encoder = partial(te.flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
......@@ -60,9 +60,9 @@ class Net(nn.Module):
x = x.reshape(x.shape[0], -1)
x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
......
......@@ -47,7 +47,7 @@ class Net(nn.Module):
@nn.compact
def __call__(self, x, disable_dropout=False):
if self.use_te:
nn_Dense = te.DenseGeneral
nn_Dense = te.flax.DenseGeneral
else:
nn_Dense = nn.Dense
......
......@@ -10,7 +10,7 @@ import jax.numpy as jnp
import pytest
from transformer_engine.common.recipe import Format
from transformer_engine.jax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper
from utils import assert_allclose, is_fp8_supported
from utils import DecoderLayer as RefDecoderLayer
......
......@@ -7,7 +7,7 @@ import numpy as np
import pytest
from jax.experimental import maps
from transformer_engine.jax import extend_logical_axis_rules
from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""The utilities for Transformer Engine"""
import inspect
import warnings
from enum import Enum

warnings.simplefilter('default')


class DeprecatedEnum:    # pylint: disable=too-few-public-methods
    """Wraps an Enum class so that member access emits a deprecation warning."""

    def __init__(self, enum_cls, msg):
        self.enum_cls = enum_cls
        self.msg = msg

    def __iter__(self):
        return iter(list(self.enum_cls.__members__.values()))

    def __getattr__(self, name):
        if name in self.enum_cls.__members__:
            warnings.warn(self.msg, DeprecationWarning)
            return self.enum_cls.__members__[name]
        raise AttributeError(f"{self.enum_cls} does not contain {name}")


def deprecate_wrapper(obj, msg):
    """Return a proxy for obj that emits a DeprecationWarning when used."""
    if inspect.isclass(obj):
        if issubclass(obj, Enum):
            return DeprecatedEnum(obj, msg)

        class DeprecatedCls(obj):    # pylint: disable=too-few-public-methods
            """DeprecatedCls"""

            def __init__(self, *args, **kwargs):
                warnings.warn(msg, DeprecationWarning)
                super().__init__(*args, **kwargs)

        return DeprecatedCls

    if inspect.isfunction(obj):

        def deprecated(*args, **kwargs):
            warnings.warn(msg, DeprecationWarning)
            return obj(*args, **kwargs)

        return deprecated

    raise NotImplementedError(
        f"deprecate_wrapper only supports classes and functions, but got {type(obj)}.")
......@@ -2,10 +2,41 @@
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper
extend_logical_axis_rules = deprecate_wrapper(
flax.extend_logical_axis_rules,
"extend_logical_axis_rules is moving to transformer_engine.jax.flax module")
DenseGeneral = deprecate_wrapper(flax.DenseGeneral,
"DenseGeneral is moving to transformer_engine.jax.flax module")
LayerNorm = deprecate_wrapper(flax.LayerNorm,
"LayerNorm is moving to transformer_engine.jax.flax module")
LayerNormDenseGeneral = deprecate_wrapper(
flax.LayerNormDenseGeneral,
"LayerNormDenseGeneral is moving to transformer_engine.jax.flax module")
LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP,
"LayerNormMLP is moving to transformer_engine.jax.flax module")
TransformerEngineBase = deprecate_wrapper(
flax.TransformerEngineBase,
"TransformerEngineBase is moving to transformer_engine.jax.flax module")
MultiHeadAttention = deprecate_wrapper(
flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module")
RelativePositionBiases = deprecate_wrapper(
flax.RelativePositionBiases,
"RelativePositionBiases is moving to transformer_engine.jax.flax module")
TransformerLayer = deprecate_wrapper(
flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module")
TransformerLayerType = deprecate_wrapper(
flax.TransformerLayerType,
"TransformerLayerType is moving to transformer_engine.jax.flax module")
__all__ = [
'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'DenseGeneral', 'LayerNorm',
'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase', 'MultiHeadAttention',
'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
]
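
Given those aliases, old-style code keeps working but warns. A short sketch of the enum case, which is why DeprecatedEnum implements __iter__:

import warnings
import transformer_engine.jax as te

# Old flat path still resolves; TransformerLayerType is now a DeprecatedEnum,
# so member access warns while iteration keeps working.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer_types = list(te.TransformerLayerType)                           # no warning
    _ = getattr(te.TransformerLayerType, layer_types[0].name)            # warns

print(len(caught), [t.name for t in layer_types])

# New canonical path, a plain Enum:
from transformer_engine.jax.flax import TransformerLayerType
print(list(TransformerLayerType))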
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
......@@ -16,15 +16,15 @@ from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from .dot import fp8_dot
from .fp8 import FP8GemmPackage, FP8Helper
from .layernorm import canonicalize_layernorm_type
from .layernorm import layernorm, layernorm_fp8_dot
from .mlp import fp8_ln_mlp, geglu
from .sharding import infer_sharding_type
from .softmax import is_softmax_kernel_available
from .sharding import MajorShardingType, ShardingType
from .softmax import softmax, SoftmaxType
from ..dot import fp8_dot
from ..fp8 import FP8GemmPackage, FP8Helper
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import fp8_ln_mlp, geglu
from ..sharding import infer_sharding_type
from ..softmax import is_softmax_kernel_available
from ..sharding import MajorShardingType, ShardingType
from ..softmax import softmax, SoftmaxType
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -46,6 +46,13 @@ def _canonicalize_tuple(x):
return (x,)
def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is None:
        if not zero_centered_gamma:
            return nn.initializers.ones
        return nn.initializers.zeros
    # Keep a user-provided initializer untouched.
    return original_init
def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
bias_axes, dtype):
scale = nn_partitioning.param_with_axes('scale',
......@@ -250,11 +257,8 @@ class LayerNorm(nn.Module):
sharding_type: ShardingType = ShardingType.SINGLE
def __post_init__(self):
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma)
super().__post_init__()
@nn.compact
......@@ -551,11 +555,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma)
super().__post_init__()
@nn.compact
......@@ -785,11 +786,8 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma)
super().__post_init__()
@nn.compact
......
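
For context on the scale_init unification: with zero_centered_gamma the learned scale is stored as an offset from one, so its default init flips from ones to zeros. A rough sketch using the standalone flax.LayerNorm (shape, seeds, and the assumption that LayerNorm can be initialized outside fp8_autocast are illustrative):

import jax
import jax.numpy as jnp
from transformer_engine.jax import flax as te_flax

x = jax.random.normal(jax.random.PRNGKey(0), (4, 256), dtype=jnp.float32)

# Default gamma: scale initialized to ones, output ~= scale * x_hat + bias.
ln = te_flax.LayerNorm(zero_centered_gamma=False)
vars_default = ln.init(jax.random.PRNGKey(1), x)

# Zero-centered gamma: scale initialized to zeros, output ~= (1 + scale) * x_hat + bias,
# so both variants start out applying an effective gamma of one.
ln_zc = te_flax.LayerNorm(zero_centered_gamma=True)
vars_zc = ln_zc.init(jax.random.PRNGKey(1), x)

# At initialization the two parameterizations should produce matching outputs.
y_default = ln.apply(vars_default, x)
y_zc = ln_zc.apply(vars_zc, x)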
......@@ -18,9 +18,9 @@ from jax import lax, vmap
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from .softmax import SoftmaxType
from .sharding import infer_major_sharding_type, infer_sharding_type
from .sharding import global_shard_resource, ShardingType
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, ShardingType
PRNGKey = Any
Shape = Tuple[int, ...]
......