Unverified Commit b20c0531 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

The Implementation of Praxis's Modules (#158)



* Adding JAX/Praxis modules and dependencies.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding UTs to JAX/Praxis modules.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Remove praxis as a dependency due to not strictly needed
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Replace is_fp8_supported with is_fp8_available
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Make Praxis as an optional dependency.

1. Removed 'from . import praxis' in __init__.py.
    1.1 Noted, keep 'from . import flax' for deprecated warning.
2. Changed te.flax to te_flax in examples and README.rst.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding a workaround to FP8 training on Praxis.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent 68f60b89
...@@ -77,6 +77,7 @@ Flax ...@@ -77,6 +77,7 @@ Flax
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe from transformer_engine.common import recipe
BATCH = 32 BATCH = 32
...@@ -93,7 +94,7 @@ Flax ...@@ -93,7 +94,7 @@ Flax
# Enable autocasting for the forward pass # Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
model = te.flax.DenseGeneral(features=HIDDEN) model = te_flax.DenseGeneral(features=HIDDEN)
def loss_fn(params, other_vars, inp): def loss_fn(params, other_vars, inp):
out = model.apply({'params':params, **other_vars}, inp) out = model.apply({'params':params, **other_vars}, inp)
......
...@@ -20,6 +20,7 @@ from jax.experimental import mesh_utils ...@@ -20,6 +20,7 @@ from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit from jax.experimental.pjit import pjit
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
DEVICE_DP_AXIS = 'data' DEVICE_DP_AXIS = 'data'
DEVICE_TP_AXIS = 'model' DEVICE_TP_AXIS = 'model'
...@@ -39,27 +40,27 @@ class Net(nn.Module): ...@@ -39,27 +40,27 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.flax.TransformerLayer, te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256, hidden_size=256,
mlp_hidden_size=1024, mlp_hidden_size=1024,
num_attention_heads=8, num_attention_heads=8,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.flax.TransformerLayerType.ENCODER, layer_type=te_flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
x = te.flax.DenseGeneral(features=256, x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,), bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL, sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x) dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,), bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW, sharding_type=te.ShardingType.DP_TP_ROW,
...@@ -174,9 +175,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -174,9 +175,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
seq_len = len(tokens) seq_len = min(len(tokens), max_seq_len)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
...@@ -275,7 +274,7 @@ def train_and_evaluate(args): ...@@ -275,7 +274,7 @@ def train_and_evaluate(args):
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect) params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
...@@ -20,6 +20,7 @@ from jax.experimental import mesh_utils ...@@ -20,6 +20,7 @@ from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit from jax.experimental.pjit import pjit
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
DEVICE_DP_AXIS = 'data' DEVICE_DP_AXIS = 'data'
PARAMS_KEY = 'params' PARAMS_KEY = 'params'
...@@ -36,24 +37,24 @@ class Net(nn.Module): ...@@ -36,24 +37,24 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.flax.TransformerLayer, te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256, hidden_size=256,
mlp_hidden_size=1024, mlp_hidden_size=1024,
num_attention_heads=8, num_attention_heads=8,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.flax.TransformerLayerType.ENCODER, layer_type=te_flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
dtype=jnp.bfloat16)(x) dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
dtype=jnp.bfloat16)(x) dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
...@@ -165,9 +166,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -165,9 +166,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
seq_len = len(tokens) seq_len = min(len(tokens), max_seq_len)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
...@@ -257,7 +256,7 @@ def train_and_evaluate(args): ...@@ -257,7 +256,7 @@ def train_and_evaluate(args):
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te.flax.extend_logical_axis_rules(tuple()) sharding_rules = te_flax.extend_logical_axis_rules(tuple())
params_pspec = get_params_pspec(sharding_rules, abs_var_collect) params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
...@@ -22,6 +22,7 @@ from jax.experimental import mesh_utils ...@@ -22,6 +22,7 @@ from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit from jax.experimental.pjit import pjit
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
DEVICE_DP_AXIS = 'data' DEVICE_DP_AXIS = 'data'
...@@ -42,27 +43,27 @@ class Net(nn.Module): ...@@ -42,27 +43,27 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.flax.TransformerLayer, te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256, hidden_size=256,
mlp_hidden_size=1024, mlp_hidden_size=1024,
num_attention_heads=8, num_attention_heads=8,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.flax.TransformerLayerType.ENCODER, layer_type=te_flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
x = te.flax.DenseGeneral(features=256, x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,), bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL, sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x) dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,), bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW, sharding_type=te.ShardingType.DP_TP_ROW,
...@@ -248,9 +249,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -248,9 +249,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
seq_len = len(tokens) seq_len = min(len(tokens), max_seq_len)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
...@@ -356,7 +355,7 @@ def train_and_evaluate(args): ...@@ -356,7 +355,7 @@ def train_and_evaluate(args):
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect) params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
...@@ -17,6 +17,7 @@ from flax.core.frozen_dict import FrozenDict ...@@ -17,6 +17,7 @@ from flax.core.frozen_dict import FrozenDict
from flax.training import train_state from flax.training import train_state
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
PARAMS_KEY = 'params' PARAMS_KEY = 'params'
DROPOUT_KEY = 'dropout' DROPOUT_KEY = 'dropout'
...@@ -31,23 +32,23 @@ class Net(nn.Module): ...@@ -31,23 +32,23 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.flax.TransformerLayer, te_Encoder = partial(te_flax.TransformerLayer,
hidden_size=256, hidden_size=256,
mlp_hidden_size=1024, mlp_hidden_size=1024,
num_attention_heads=8, num_attention_heads=8,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.flax.TransformerLayerType.ENCODER, layer_type=te_flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x return x
...@@ -160,9 +161,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -160,9 +161,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
seq_len = len(tokens) seq_len = min(len(tokens), max_seq_len)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0 mask_2d[:seq_len, :seq_len] = 0
......
...@@ -16,6 +16,7 @@ from flax.core.frozen_dict import FrozenDict ...@@ -16,6 +16,7 @@ from flax.core.frozen_dict import FrozenDict
from flax.training import train_state from flax.training import train_state
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
IMAGE_H = 28 IMAGE_H = 28
IMAGE_W = 28 IMAGE_W = 28
...@@ -32,7 +33,7 @@ class Net(nn.Module): ...@@ -32,7 +33,7 @@ class Net(nn.Module):
@nn.compact @nn.compact
def __call__(self, x, disable_dropout=False): def __call__(self, x, disable_dropout=False):
if self.use_te: if self.use_te:
nn_Dense = te.flax.DenseGeneral nn_Dense = te_flax.DenseGeneral
else: else:
nn_Dense = nn.Dense nn_Dense = nn.Dense
......
This diff is collapsed.
...@@ -36,7 +36,7 @@ TransformerLayerType = deprecate_wrapper( ...@@ -36,7 +36,7 @@ TransformerLayerType = deprecate_wrapper(
__all__ = [ __all__ = [
'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling', 'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'DenseGeneral', 'LayerNorm', 'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis', 'DenseGeneral',
'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase', 'MultiHeadAttention', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase',
'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType' 'MultiHeadAttention', 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
] ]
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import MultiHeadAttention, RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle
from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType
from ..sharding import MajorShardingType, ShardingType
def _generate_ln_scale_init(scale_init):
if scale_init is not None:
return TransformerEngineBaseLayer.generate_params_init("scale", scale_init)
return scale_init
class TransformerEngineBaseLayer(BaseLayer):
    """Common base for Praxis wrappers around Transformer Engine Flax modules.

    Subclasses build their underlying Flax module factory and register it as
    a Praxis child layer through :meth:`create_layer`, which adapts the Flax
    module with ``flax_adapter.FlaxModuleAdapter``.
    """

    # Logical-to-mesh axis rules forwarded to the Flax adapter.
    # NOTE(review): annotated as a tuple but defaults to None — presumably the
    # adapter accepts None as "no rules"; confirm against FlaxModuleAdapter.
    logical_axes_rules: Tuple[Tuple, ...] = None

    @staticmethod
    def generate_params_init(name: str, initializer: WeightInit):
        """Return a Flax-style initializer ``(key, shape, dtype) -> value``.

        The returned callable draws the parameter via Praxis' ``init_var``
        using the given ``WeightInit`` spec, registered under ``name``.
        """

        def kernel_init(key, shape, dtype):
            # Build the Praxis weight hyperparams lazily so shape/dtype come
            # from the Flax module at init time.
            wp = WeightHParams(shape=shape, init=initializer, dtype=dtype)
            return init_var(wp, key, name)

        return kernel_init

    def create_layer(self, name, flax_module_cls):
        """Register the Flax module built by ``flax_module_cls`` (a zero-arg
        factory) as the Praxis child layer called ``name``."""
        flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter,
                                          module_factory_method=flax_module_cls,
                                          logical_axes_rules=self.logical_axes_rules,
                                          ici_mesh_shape=self.ici_mesh_shape,
                                          dcn_mesh_shape=self.dcn_mesh_shape,
                                          mesh_axis_names=self.mesh_axis_names)
        # clone() so the template config is not shared/mutated across layers.
        self.create_child(name, flax_module_p.clone())
class LayerNorm(TransformerEngineBaseLayer):
    """Praxis wrapper of Transformer Engine's Flax ``LayerNorm`` module.

    All fields are forwarded verbatim to the underlying Flax module in
    :meth:`setup`; see that module for the authoritative semantics.
    """

    epsilon: float = 1e-6                      # added to variance for stability
    layernorm_type: str = 'layernorm'          # e.g. 'layernorm' or 'rmsnorm' — confirm supported values
    zero_centered_gamma: bool = False
    scale_init: WeightInit = None              # None → flax module's default scale init
    scale_axes: Tuple[str, ...] = ()
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False
    sharding_type: ShardingType = ShardingType.SINGLE

    def setup(self) -> None:
        """Build the wrapped Flax LayerNorm as child layer ``layer_norm``."""
        super().setup()
        ln_cls = partial(flax_LayerNorm,
                         epsilon=self.epsilon,
                         layernorm_type=self.layernorm_type,
                         zero_centered_gamma=self.zero_centered_gamma,
                         scale_init=_generate_ln_scale_init(self.scale_init),
                         scale_axes=self.scale_axes,
                         bias_init=TransformerEngineBaseLayer.generate_params_init(
                             "ln_bias", self.bias_init),
                         bias_axes=self.bias_axes,
                         dtype=self.dtype,
                         transpose_batch_sequence=self.transpose_batch_sequence,
                         sharding_type=self.sharding_type)
        self.create_layer("layer_norm", ln_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Apply layer normalization to ``x``."""
        return self.layer_norm(x)
class FusedSoftmax(TransformerEngineBaseLayer):
    """Praxis wrapper of Transformer Engine's fused Flax ``Softmax`` module."""

    scale_factor: float = 1.0                  # logits are scaled by this factor
    softmax_type: SoftmaxType = SoftmaxType.SCALED
    sharding_type: ShardingType = ShardingType.SINGLE

    def setup(self) -> None:
        """Build the wrapped Flax Softmax as child layer ``fused_softmax``."""
        super().setup()
        fused_softmax_cls = partial(Softmax,
                                    scale_factor=self.scale_factor,
                                    softmax_type=self.softmax_type,
                                    sharding_type=self.sharding_type)
        self.create_layer("fused_softmax", fused_softmax_cls)

    def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
        """Apply (optionally masked/biased) softmax to ``x``.

        ``mask`` and ``bias`` are optional and forwarded positionally to the
        underlying Flax module.
        """
        return self.fused_softmax(x, mask, bias)
class Linear(TransformerEngineBaseLayer):
    """Praxis wrapper of Transformer Engine's Flax ``DenseGeneral`` module."""

    out_features: int = 512                    # size of the output feature dim
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes: Tuple[str, ...] = ()
    axis: Union[Iterable[int], int] = -1       # input axis (or axes) to contract
    transpose_batch_sequence: bool = False
    sharding_type: ShardingType = ShardingType.SINGLE

    def setup(self) -> None:
        """Build the wrapped Flax DenseGeneral as child layer ``linear``."""
        super().setup()
        dense_general_cls = partial(
            DenseGeneral,
            features=self.out_features,
            # Kernel weights use the layer's Praxis params_init spec.
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes=self.kernel_axes,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
            sharding_type=self.sharding_type)
        self.create_layer("linear", dense_general_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Apply the dense (linear) projection to ``x``."""
        return self.linear(x)
class LayerNormLinear(TransformerEngineBaseLayer):
    """Praxis wrapper of TE's Flax ``LayerNormDenseGeneral`` (fused LN + dense)."""

    out_features: int = 512
    enable_layernorm: bool = True              # False → plain dense, no LN
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: WeightInit = None              # None → flax module's default scale init
    scale_axes: Tuple[str, ...] = ()
    # NOTE(review): ln bias defaults to 1.0 here vs 0.0 in LayerNorm above —
    # looks intentional but worth confirming against the Flax module defaults.
    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
    ln_bias_axes: Tuple[str, ...] = ()
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True       # also return the LN output for residuals
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False
    depth_scaling: float = None                # None → no depth-based init scaling
    sharding_type: ShardingType = ShardingType.SINGLE

    def setup(self) -> None:
        """Build the wrapped Flax LayerNormDenseGeneral as child ``ln_linear``."""
        super().setup()
        ln_dense_general_cls = partial(
            LayerNormDenseGeneral,
            features=self.out_features,
            enable_layernorm=self.enable_layernorm,
            layernorm_type=self.layernorm_type,
            epsilon=self.epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
                "ln_bias", self.ln_bias_init),
            ln_bias_axes=self.ln_bias_axes,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes=self.kernel_axes,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            return_layernorm_output=self.return_layernorm_output,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
            depth_scaling=self.depth_scaling,
            sharding_type=self.sharding_type)
        self.create_layer("ln_linear", ln_dense_general_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Apply fused layernorm + dense projection to ``x``."""
        return self.ln_linear(x)
class LayerNormMLP(TransformerEngineBaseLayer):
    """Praxis wrapper of TE's Flax ``LayerNormMLP`` (fused LN + 2-layer MLP)."""

    intermediate_dim: int = 2048               # hidden width of the MLP
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: WeightInit = None              # None → flax module's default scale init
    scale_axes: Tuple[str, ...] = ()
    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
    ln_bias_axes: Tuple[str, ...] = ()
    kernel_axes_1: Tuple[str, ...] = ()        # axes for the first (expand) kernel
    kernel_axes_2: Tuple[str, ...] = ()        # axes for the second (contract) kernel
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes_1: Tuple[str, ...] = ()
    bias_axes_2: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False
    # NOTE(review): this class takes a MajorShardingType while the sibling
    # layers take ShardingType — mirrors the underlying Flax module's API.
    major_sharding_type: MajorShardingType = MajorShardingType.SINGLE

    def setup(self) -> None:
        """Build the wrapped Flax LayerNormMLP as child layer ``ln_mlp``."""
        super().setup()
        ln_mlp_cls = partial(
            flax_LayerNormMLP,
            intermediate_dim=self.intermediate_dim,
            enable_layernorm=self.enable_layernorm,
            layernorm_type=self.layernorm_type,
            epsilon=self.epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
                "ln_bias", self.ln_bias_init),
            ln_bias_axes=self.ln_bias_axes,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes_1=self.kernel_axes_1,
            kernel_axes_2=self.kernel_axes_2,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes_1=self.bias_axes_1,
            bias_axes_2=self.bias_axes_2,
            return_layernorm_output=self.return_layernorm_output,
            activations=self.activations,
            intermediate_dropout_rate=self.intermediate_dropout_rate,
            intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
            major_sharding_type=self.major_sharding_type)
        self.create_layer("ln_mlp", ln_mlp_cls)

    def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor:
        """Apply LN + MLP to ``x``; ``deterministic=True`` disables dropout."""
        return self.ln_mlp(x, deterministic)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related Transformer
"""
from functools import partial
from typing import Optional, Sequence, Tuple
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import AttentionType, TransformerLayerType
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
class RelativePositionBiases(TransformerEngineBaseLayer):
    """Praxis wrapper of TE's Flax ``RelativePositionBiases`` (T5-style)."""

    num_buckets: int = 32                      # relative-distance buckets
    max_distance: int = 128                    # distances beyond this share a bucket
    num_attention_heads: int = 64
    embedding_init: WeightInit = None          # None → Gaussian, see generate_embedding_init
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """Return ``init`` or, if it is None, a Gaussian init with stddev
        ``(num_attention_heads * num_buckets) ** -0.5``."""
        embedding_init = init
        if embedding_init is None:
            rb_stddev = (num_attention_heads * num_buckets)**-0.5
            embedding_init = WeightInit.Gaussian(rb_stddev)
        return embedding_init

    def setup(self) -> None:
        """Build the wrapped Flax module as child ``relative_position_bias``."""
        super().setup()
        embedding_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets)
        rpb_cls = partial(flax_RelativePositionBiases,
                          num_buckets=self.num_buckets,
                          max_distance=self.max_distance,
                          num_attention_heads=self.num_attention_heads,
                          embedding_init=TransformerEngineBaseLayer.generate_params_init(
                              "rel_embedding", embedding_init),
                          embedding_axes=self.embedding_axes,
                          dtype=self.dtype)
        self.create_layer("relative_position_bias", rpb_cls)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """Return the relative position bias for the given query/key lengths."""
        return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class MultiHeadAttention(TransformerEngineBaseLayer):
    """Praxis wrapper of TE's Flax ``MultiHeadAttention`` module."""

    head_dim: int = 64                         # per-head hidden dimension
    num_heads: int = 16
    dropout_rate: float = 0.
    dropout_rng_name: str = 'dropout'
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    attn_type: AttentionType = AttentionType.PADDING
    fuse_qkv: bool = True                      # fuse Q/K/V projections into one kernel
    # NOTE(review): defaults to True here, unlike TransformerLayer below (False).
    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False               # compute attention logits in fp32

    def setup(self) -> None:
        """Build the wrapped Flax module as child ``multi_head_attn``."""
        super().setup()
        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            dropout_rng_name=self.dropout_rng_name,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            attn_type=self.attn_type,
            fuse_qkv=self.fuse_qkv,
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits)
        self.create_layer("multi_head_attn", mha_cls)

    def __call__(self,
                 inputs_q: JTensor,
                 inputs_kv: JTensor,
                 mask: Optional[JTensor] = None,
                 bias: Optional[JTensor] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> JTensor:
        """Run attention of ``inputs_q`` over ``inputs_kv``.

        ``decode`` and ``deterministic`` are keyword-only and forwarded to
        the underlying Flax module.
        """
        return self.multi_head_attn(inputs_q,
                                    inputs_kv,
                                    mask,
                                    bias,
                                    decode=decode,
                                    deterministic=deterministic)
class TransformerLayer(TransformerEngineBaseLayer):
    """Praxis wrapper of TE's Flax ``TransformerLayer``.

    Optionally builds a relative-position-bias submodule from the
    ``relative_embedding`` template and hands the resulting Flax module to
    the wrapped transformer layer.
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    dropout_rng_name: str = 'dropout'
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    enable_relative_embedding: bool = True
    # Template config for an optional RelativePositionBiases submodule; only
    # used when enable_relative_embedding is True.
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def setup(self) -> None:
        """Build the wrapped Flax TransformerLayer as child ``transformerlayer``."""
        super().setup()

        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            # FIX: the two adjacent string literals previously concatenated
            # without a space ("shoule bethe same") and contained a typo.
            assert self.relative_embedding.num_attention_heads == \
                self.num_attention_heads, \
                "TransformerLayer.relative_embedding.num_attention_heads should be " \
                "the same as TransformerLayer.num_attention_heads."

            embedding_init = RelativePositionBiases.generate_embedding_init(
                self.relative_embedding.embedding_init, self.relative_embedding.num_attention_heads,
                self.relative_embedding.num_buckets)

            # Instantiate the Flax submodule directly (not via create_layer)
            # because it is passed into the transformer layer as an argument.
            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=self.relative_embedding.num_buckets,
                max_distance=self.relative_embedding.max_distance,
                num_attention_heads=self.relative_embedding.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init),
                embedding_axes=self.relative_embedding.embedding_axes,
                dtype=self.relative_embedding.dtype)

        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init)
        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(self,
                 inputs: JTensor,
                 encoded: Optional[JTensor] = None,
                 attention_mask: Optional[JTensor] = None,
                 encoder_decoder_mask: Optional[JTensor] = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None) -> JTensor:
        """Run the transformer layer.

        ``encoded`` / ``encoder_decoder_mask`` are used for decoder layers;
        ``max_decode_length`` was mis-annotated as ``bool`` — it is an
        optional decode length forwarded to the Flax layer.
        """
        return self.transformerlayer(inputs, encoded, attention_mask, encoder_decoder_mask,
                                     deterministic, decode, max_decode_length)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment