Unverified Commit 0792ded4 authored by Ming-Xu Huang, committed by GitHub

[JAX] Adjust Module Structure. (#169)



* Adjust Module Structure.

1. Collect the Flax-related modules into a sub-folder, flax.
2. Add a function to unify scale_init for zero-centered-gamma LayerNorm.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make changes compatible with previous versions.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adapt jax/examples to the new module structure.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update jax/docs and add a deprecation warning.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update README
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add deprecate_wrapper.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add deprecation warnings to the Flax modules imported via transformer_engine.jax.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix CI errors and update docs.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove unnecessary deprecation warnings from the docs.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Implement __iter__ for DeprecatedEnum.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 522fecc1
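In short, this change moves the Flax modules under transformer_engine.jax.flax while keeping the old top-level names importable behind deprecation shims. A minimal migration sketch (illustrative only; the feature size is arbitrary, and the warning behaviour follows the deprecate_wrapper added in this commit):

    import transformer_engine.jax as te

    # Old path: still resolves, but goes through the deprecation shim added in this commit.
    dense = te.DenseGeneral(features=256)

    # New path: Flax modules are imported from the dedicated sub-module.
    dense = te.flax.DenseGeneral(features=256)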
@@ -69,6 +69,9 @@ pyTorch
 JAX
 ^^^
+Flax
+~~~~
 .. code-block:: python

     import jax
@@ -90,7 +93,7 @@ JAX
     # Enable autocasting for the forward pass
     with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-        model = te.DenseGeneral(features=HIDDEN)
+        model = te.flax.DenseGeneral(features=HIDDEN)

         def loss_fn(params, other_vars, inp):
             out = model.apply({'params':params, **other_vars}, inp)
...
@@ -9,34 +9,33 @@ Jax
 .. autoapiclass:: transformer_engine.jax.MajorShardingType
 .. autoapiclass:: transformer_engine.jax.ShardingType
 .. autoapiclass:: transformer_engine.jax.TransformerLayerType
 .. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
+.. autoapifunction:: transformer_engine.jax.fp8_autocast
+.. autoapifunction:: transformer_engine.jax.update_collections
+.. autoapifunction:: transformer_engine.jax.update_fp8_metas
-.. autoapiclass:: transformer_engine.jax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
   :members: __call__
-.. autoapiclass:: transformer_engine.jax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
   :members: __call__
-.. autoapiclass:: transformer_engine.jax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
   :members: __call__
-.. autoapiclass:: transformer_engine.jax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
   :members: __call__
-.. autoapiclass:: transformer_engine.jax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
   :members: __call__
-.. autoapiclass:: transformer_engine.jax.MultiHeadAttention(head_dim, num_heads, **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
   :members: __call__
-.. autoapiclass:: transformer_engine.jax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
+.. autoapiclass:: transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
   :members: __call__
-.. autoapifunction:: transformer_engine.jax.extend_logical_axis_rules
-.. autoapifunction:: transformer_engine.jax.fp8_autocast
-.. autoapifunction:: transformer_engine.jax.update_collections
-.. autoapifunction:: transformer_engine.jax.update_fp8_metas
+.. autoapifunction:: transformer_engine.jax.flax.extend_logical_axis_rules
\ No newline at end of file
@@ -59,7 +59,7 @@ class Net(nn.Module):
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-       te_Encoder = partial(te.TransformerLayer,
+       te_Encoder = partial(te.flax.TransformerLayer,
                             hidden_size=256,
                             mlp_hidden_size=1024,
                             num_attention_heads=8,
@@ -73,17 +73,17 @@ class Net(nn.Module):
        x = x.reshape(x.shape[0], -1)
-       x = te.DenseGeneral(features=256,
+       x = te.flax.DenseGeneral(features=256,
                            kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                            bias_axes=(NAMED_TP_AXIS,),
                            sharding_type=te.ShardingType.DP_TP_COL,
                            dtype=jnp.bfloat16)(x)
-       x = te.DenseGeneral(features=256,
+       x = te.flax.DenseGeneral(features=256,
                            kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
                            bias_axes=(NAMED_BROADCAST_AXIS,),
                            sharding_type=te.ShardingType.DP_TP_ROW,
                            dtype=jnp.bfloat16)(x)
        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x
...
@@ -56,7 +56,7 @@ class Net(nn.Module):
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-       te_Encoder = partial(te.TransformerLayer,
+       te_Encoder = partial(te.flax.TransformerLayer,
                             hidden_size=256,
                             mlp_hidden_size=1024,
                             num_attention_heads=8,
@@ -70,9 +70,11 @@ class Net(nn.Module):
        x = x.reshape(x.shape[0], -1)
-       x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x)
+       x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
+                                dtype=jnp.bfloat16)(x)
-       x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x)
+       x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
+                                dtype=jnp.bfloat16)(x)
        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x
...
@@ -46,7 +46,7 @@ class Net(nn.Module):
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-       te_Encoder = partial(te.TransformerLayer,
+       te_Encoder = partial(te.flax.TransformerLayer,
                             hidden_size=256,
                             mlp_hidden_size=1024,
                             num_attention_heads=8,
@@ -60,9 +60,9 @@ class Net(nn.Module):
        x = x.reshape(x.shape[0], -1)
-       x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
-       x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+       x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+       x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x
...
@@ -47,7 +47,7 @@ class Net(nn.Module):
    @nn.compact
    def __call__(self, x, disable_dropout=False):
        if self.use_te:
-           nn_Dense = te.DenseGeneral
+           nn_Dense = te.flax.DenseGeneral
        else:
            nn_Dense = nn.Dense
...
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 import pytest
 from transformer_engine.common.recipe import Format
-from transformer_engine.jax import TransformerLayer, TransformerLayerType
+from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
 from transformer_engine.jax.fp8 import FP8Helper
 from utils import assert_allclose, is_fp8_supported
 from utils import DecoderLayer as RefDecoderLayer
...
@@ -7,7 +7,7 @@ import numpy as np
 import pytest
 from jax.experimental import maps
-from transformer_engine.jax import extend_logical_axis_rules
+from transformer_engine.jax.flax import extend_logical_axis_rules
 from transformer_engine.jax.sharding import get_dot_sharding_meta
 from transformer_engine.jax.sharding import get_elementwise_sharding_meta
 from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""The utilities for Transformer Engine"""
import inspect
import warnings
from enum import Enum

warnings.simplefilter('default')


class DeprecatedEnum:    # pylint: disable=too-few-public-methods
    """DeprecatedEnum"""

    def __init__(self, enum_cls, msg):
        self.enum_cls = enum_cls
        self.msg = msg

    def __iter__(self):
        return iter(list(self.enum_cls.__members__.values()))

    def __getattr__(self, name):
        if name in self.enum_cls.__members__:
            warnings.warn(self.msg, DeprecationWarning)
            return self.enum_cls.__members__[name]
        raise AttributeError(f"{self.enum_cls} does not contain {name}")


def deprecate_wrapper(obj, msg):
    """Deprecate wrapper"""
    if inspect.isclass(obj):
        if issubclass(obj, Enum):
            return DeprecatedEnum(obj, msg)

        class DeprecatedCls(obj):    # pylint: disable=too-few-public-methods
            """DeprecatedCls"""

            def __init__(self, *args, **kwargs):
                warnings.warn(msg, DeprecationWarning)
                super().__init__(*args, **kwargs)

        return DeprecatedCls

    if inspect.isfunction(obj):

        def deprecated(*args, **kwargs):
            warnings.warn(msg, DeprecationWarning)
            return obj(*args, **kwargs)

        return deprecated

    raise NotImplementedError(
        f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}.")
@@ -2,10 +2,41 @@
 #
 # See LICENSE for license information.
 """Transformer Engine bindings for JAX"""
+from . import flax
 from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
-from .module import DenseGeneral, LayerNorm
-from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
-from .transformer import extend_logical_axis_rules
-from .transformer import MultiHeadAttention, RelativePositionBiases
-from .transformer import TransformerLayer, TransformerLayerType
 from .sharding import MajorShardingType, ShardingResource, ShardingType
+from ..common.utils import deprecate_wrapper
+
+extend_logical_axis_rules = deprecate_wrapper(
+    flax.extend_logical_axis_rules,
+    "extend_logical_axis_rules is moving to transformer_engine.jax.flax module")
+DenseGeneral = deprecate_wrapper(flax.DenseGeneral,
+                                 "DenseGeneral is moving to transformer_engine.jax.flax module")
+LayerNorm = deprecate_wrapper(flax.LayerNorm,
+                              "LayerNorm is moving to transformer_engine.jax.flax module")
+LayerNormDenseGeneral = deprecate_wrapper(
+    flax.LayerNormDenseGeneral,
+    "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module")
+LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP,
+                                 "LayerNormMLP is moving to transformer_engine.jax.flax module")
+TransformerEngineBase = deprecate_wrapper(
+    flax.TransformerEngineBase,
+    "TransformerEngineBase is moving to transformer_engine.jax.flax module")
+MultiHeadAttention = deprecate_wrapper(
+    flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module")
+RelativePositionBiases = deprecate_wrapper(
+    flax.RelativePositionBiases,
+    "RelativePositionBiases is moving to transformer_engine.jax.flax module")
+TransformerLayer = deprecate_wrapper(
+    flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module")
+TransformerLayerType = deprecate_wrapper(
+    flax.TransformerLayerType,
+    "TransformerLayerType is moving to transformer_engine.jax.flax module")
+
+__all__ = [
+    'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
+    'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'DenseGeneral', 'LayerNorm',
+    'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase', 'MultiHeadAttention',
+    'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
+]
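The practical effect of these aliases, sketched below (illustrative; warning text abbreviated, and ENCODER is an existing member of TransformerLayerType):

    import transformer_engine.jax as te

    layer_type = te.TransformerLayerType.ENCODER   # emits DeprecationWarning via DeprecatedEnum
    assert layer_type is te.flax.TransformerLayerType.ENCODER   # new, warning-free path

    # __iter__ keeps enum-style iteration working on the deprecated alias.
    print(list(te.TransformerLayerType))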
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
@@ -16,15 +16,15 @@ from jax import lax
 from jax import nn as jax_nn
 from jax import random as jax_random
-from .dot import fp8_dot
-from .fp8 import FP8GemmPackage, FP8Helper
-from .layernorm import canonicalize_layernorm_type
-from .layernorm import layernorm, layernorm_fp8_dot
-from .mlp import fp8_ln_mlp, geglu
-from .sharding import infer_sharding_type
-from .softmax import is_softmax_kernel_available
-from .sharding import MajorShardingType, ShardingType
-from .softmax import softmax, SoftmaxType
+from ..dot import fp8_dot
+from ..fp8 import FP8GemmPackage, FP8Helper
+from ..layernorm import canonicalize_layernorm_type
+from ..layernorm import layernorm, layernorm_fp8_dot
+from ..mlp import fp8_ln_mlp, geglu
+from ..sharding import infer_sharding_type
+from ..softmax import is_softmax_kernel_available
+from ..sharding import MajorShardingType, ShardingType
+from ..softmax import softmax, SoftmaxType

 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -46,6 +46,13 @@ def _canonicalize_tuple(x):
    return (x,)

+
+def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
+    if original_init is None:
+        if not zero_centered_gamma:
+            return nn.initializers.ones
+        return nn.initializers.zeros
+    return original_init
+

def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
                                 bias_axes, dtype):
    scale = nn_partitioning.param_with_axes('scale',
@@ -250,11 +257,8 @@ class LayerNorm(nn.Module):
    sharding_type: ShardingType = ShardingType.SINGLE

    def __post_init__(self):
-       if self.scale_init is None:
-           if not self.zero_centered_gamma:
-               self.scale_init = nn.initializers.ones
-           else:
-               self.scale_init = nn.initializers.zeros
+       self.scale_init = _obtain_default_layernorm_scale_init_if_need(
+           self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
@@ -551,11 +555,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
-       if self.scale_init is None:
-           if not self.zero_centered_gamma:
-               self.scale_init = nn.initializers.ones
-           else:
-               self.scale_init = nn.initializers.zeros
+       self.scale_init = _obtain_default_layernorm_scale_init_if_need(
+           self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
@@ -785,11 +786,8 @@ class LayerNormMLP(TransformerEngineBase):
    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
-       if self.scale_init is None:
-           if not self.zero_centered_gamma:
-               self.scale_init = nn.initializers.ones
-           else:
-               self.scale_init = nn.initializers.zeros
+       self.scale_init = _obtain_default_layernorm_scale_init_if_need(
+           self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
...
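For reference, the zero-centered-gamma convention that the unified scale_init default serves can be written as plain JAX (a reference sketch of the math only, not the fused TE kernel): with zero_centered_gamma=True the scale parameter starts at zero and is applied as 1 + gamma, so a freshly initialized layer behaves like a standard LayerNorm with a ones-initialized scale.

    import jax.numpy as jnp

    def layernorm_ref(x, gamma, beta, zero_centered_gamma, eps=1e-6):
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        normed = (x - mean) / jnp.sqrt(var + eps)
        # zeros-initialized gamma applied as (1 + gamma) matches ones-initialized gamma at init.
        scale = 1.0 + gamma if zero_centered_gamma else gamma
        return normed * scale + beta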
@@ -18,9 +18,9 @@ from jax import lax, vmap
 from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
-from .softmax import SoftmaxType
-from .sharding import infer_major_sharding_type, infer_sharding_type
-from .sharding import global_shard_resource, ShardingType
+from ..softmax import SoftmaxType
+from ..sharding import infer_major_sharding_type, infer_sharding_type
+from ..sharding import global_shard_resource, ShardingType

 PRNGKey = Any
 Shape = Tuple[int, ...]
...