Unverified Commit b20c0531 authored by Ming-Xu Huang, committed by GitHub

The Implementation of Praxis's Modules (#158)



* Adding JAX/Praxis modules and dependencies.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding UTs to JAX/Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove praxis as a dependency, since it is not strictly needed.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Replace is_fp8_supported with is_fp8_available.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make Praxis an optional dependency.

1. Removed 'from . import praxis' in __init__.py.
    1.1 Note: kept 'from . import flax' for the deprecation warning.
2. Changed te.flax to te_flax in examples and README.rst.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding a workaround for FP8 training on Praxis.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 68f60b89
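Since Praxis is now optional, nothing under transformer_engine.jax.praxis is imported eagerly; users opt in explicitly. A minimal sketch of the resulting import pattern, assuming the optional praxis package is installed (the alias names are illustrative):

import transformer_engine.jax as te            # no longer pulls in Praxis
import transformer_engine.jax.flax as te_flax  # Flax modules, imported explicitly

# Only valid when the optional praxis dependency is installed:
import transformer_engine.jax.praxis as te_praxis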
@@ -77,6 +77,7 @@ Flax
 import jax
 import jax.numpy as jnp
 import transformer_engine.jax as te
+import transformer_engine.jax.flax as te_flax
 from transformer_engine.common import recipe
 BATCH = 32
@@ -93,7 +94,7 @@ Flax
 # Enable autocasting for the forward pass
 with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-    model = te.flax.DenseGeneral(features=HIDDEN)
+    model = te_flax.DenseGeneral(features=HIDDEN)
     def loss_fn(params, other_vars, inp):
         out = model.apply({'params':params, **other_vars}, inp)
......
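For reference, a runnable completion of the README snippet above. This is a hedged sketch, since the collapsed lines are not shown in this diff: SEQLEN, the input construction, and the DelayedScaling arguments are assumptions, and an FP8-capable GPU is required for the autocast to take effect.

import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128    # assumed; not shown in the diff
HIDDEN = 1024   # assumed; not shown in the diff

key = jax.random.PRNGKey(0)
inp = jax.random.normal(key, (BATCH, SEQLEN, HIDDEN), jnp.float32)
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)
    variables = model.init(key, inp)
    other_variables, params = variables.pop('params')

    def loss_fn(params, other_vars, inp):
        out = model.apply({'params': params, **other_vars}, inp)
        return jnp.mean(out)

    # Gradients w.r.t. the parameters, matching the README's loss_fn
    grads = jax.grad(loss_fn)(params, other_variables, inp)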
@@ -20,6 +20,7 @@ from jax.experimental import mesh_utils
 from jax.experimental.pjit import pjit
 import transformer_engine.jax as te
+import transformer_engine.jax.flax as te_flax
 DEVICE_DP_AXIS = 'data'
 DEVICE_TP_AXIS = 'model'
@@ -39,27 +40,27 @@ class Net(nn.Module):
     def __call__(self, x, mask, disable_dropout=False):
         x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-        te_Encoder = partial(te.flax.TransformerLayer,
+        te_Encoder = partial(te_flax.TransformerLayer,
                              hidden_size=256,
                              mlp_hidden_size=1024,
                              num_attention_heads=8,
                              hidden_dropout=0.1,
                              attention_dropout=0.1,
                              dropout_rng_name=DROPOUT_KEY,
-                             layer_type=te.flax.TransformerLayerType.ENCODER,
+                             layer_type=te_flax.TransformerLayerType.ENCODER,
                              enable_relative_embedding=False,
                              dtype=jnp.bfloat16)
         x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
         x = x.reshape(x.shape[0], -1)
-        x = te.flax.DenseGeneral(features=256,
+        x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                                  bias_axes=(NAMED_TP_AXIS,),
                                  sharding_type=te.ShardingType.DP_TP_COL,
                                  dtype=jnp.bfloat16)(x)
-        x = te.flax.DenseGeneral(features=256,
+        x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
                                  bias_axes=(NAMED_BROADCAST_AXIS,),
                                  sharding_type=te.ShardingType.DP_TP_ROW,
@@ -174,9 +175,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
             else:
                 tensor[i] = vocab[word]
-        seq_len = len(tokens)
-        if seq_len > max_seq_len:
-            seq_len = max_seq_len
+        seq_len = min(len(tokens), max_seq_len)
         mask_2d = mask_3d[j]
         mask_2d[:seq_len, :seq_len] = 0
@@ -275,7 +274,7 @@ def train_and_evaluate(args):
     abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
     customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
-    sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules
+    sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
     params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
     inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
     masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
@@ -20,6 +20,7 @@ from jax.experimental import mesh_utils
 from jax.experimental.pjit import pjit
 import transformer_engine.jax as te
+import transformer_engine.jax.flax as te_flax
 DEVICE_DP_AXIS = 'data'
 PARAMS_KEY = 'params'
@@ -36,24 +37,24 @@ class Net(nn.Module):
     def __call__(self, x, mask, disable_dropout=False):
         x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-        te_Encoder = partial(te.flax.TransformerLayer,
+        te_Encoder = partial(te_flax.TransformerLayer,
                              hidden_size=256,
                              mlp_hidden_size=1024,
                              num_attention_heads=8,
                              hidden_dropout=0.1,
                              attention_dropout=0.1,
                              dropout_rng_name=DROPOUT_KEY,
-                             layer_type=te.flax.TransformerLayerType.ENCODER,
+                             layer_type=te_flax.TransformerLayerType.ENCODER,
                              enable_relative_embedding=False,
                              dtype=jnp.bfloat16)
         x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
         x = x.reshape(x.shape[0], -1)
-        x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
+        x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
                                  dtype=jnp.bfloat16)(x)
-        x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
+        x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
                                  dtype=jnp.bfloat16)(x)
         x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
@@ -165,9 +166,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
             else:
                 tensor[i] = vocab[word]
-        seq_len = len(tokens)
-        if seq_len > max_seq_len:
-            seq_len = max_seq_len
+        seq_len = min(len(tokens), max_seq_len)
         mask_2d = mask_3d[j]
         mask_2d[:seq_len, :seq_len] = 0
@@ -257,7 +256,7 @@ def train_and_evaluate(args):
     masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
     abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-    sharding_rules = te.flax.extend_logical_axis_rules(tuple())
+    sharding_rules = te_flax.extend_logical_axis_rules(tuple())
     params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
     inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
     masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
@@ -22,6 +22,7 @@ from jax.experimental import mesh_utils
 from jax.experimental.pjit import pjit
 import transformer_engine.jax as te
+import transformer_engine.jax.flax as te_flax
 os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
 DEVICE_DP_AXIS = 'data'
@@ -42,27 +43,27 @@ class Net(nn.Module):
     def __call__(self, x, mask, disable_dropout=False):
         x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-        te_Encoder = partial(te.flax.TransformerLayer,
+        te_Encoder = partial(te_flax.TransformerLayer,
                              hidden_size=256,
                              mlp_hidden_size=1024,
                              num_attention_heads=8,
                              hidden_dropout=0.1,
                              attention_dropout=0.1,
                              dropout_rng_name=DROPOUT_KEY,
-                             layer_type=te.flax.TransformerLayerType.ENCODER,
+                             layer_type=te_flax.TransformerLayerType.ENCODER,
                              enable_relative_embedding=False,
                              dtype=jnp.bfloat16)
         x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
         x = x.reshape(x.shape[0], -1)
-        x = te.flax.DenseGeneral(features=256,
+        x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                                  bias_axes=(NAMED_TP_AXIS,),
                                  sharding_type=te.ShardingType.DP_TP_COL,
                                  dtype=jnp.bfloat16)(x)
-        x = te.flax.DenseGeneral(features=256,
+        x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
                                  bias_axes=(NAMED_BROADCAST_AXIS,),
                                  sharding_type=te.ShardingType.DP_TP_ROW,
@@ -248,9 +249,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
             else:
                 tensor[i] = vocab[word]
-        seq_len = len(tokens)
-        if seq_len > max_seq_len:
-            seq_len = max_seq_len
+        seq_len = min(len(tokens), max_seq_len)
         mask_2d = mask_3d[j]
         mask_2d[:seq_len, :seq_len] = 0
@@ -356,7 +355,7 @@ def train_and_evaluate(args):
     abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
     customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
-    sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules
+    sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
     params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
     inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
     masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
@@ -17,6 +17,7 @@ from flax.core.frozen_dict import FrozenDict
 from flax.training import train_state
 import transformer_engine.jax as te
+import transformer_engine.jax.flax as te_flax
 PARAMS_KEY = 'params'
 DROPOUT_KEY = 'dropout'
@@ -31,23 +32,23 @@ class Net(nn.Module):
     def __call__(self, x, mask, disable_dropout=False):
         x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
-        te_Encoder = partial(te.flax.TransformerLayer,
+        te_Encoder = partial(te_flax.TransformerLayer,
                              hidden_size=256,
                              mlp_hidden_size=1024,
                              num_attention_heads=8,
                              hidden_dropout=0.1,
                              attention_dropout=0.1,
                              dropout_rng_name=DROPOUT_KEY,
-                             layer_type=te.flax.TransformerLayerType.ENCODER,
+                             layer_type=te_flax.TransformerLayerType.ENCODER,
                              enable_relative_embedding=False,
                              dtype=jnp.bfloat16)
         x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
         x = x.reshape(x.shape[0], -1)
-        x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
-        x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
         x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
         return x
@@ -160,9 +161,7 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
             else:
                 tensor[i] = vocab[word]
-        seq_len = len(tokens)
-        if seq_len > max_seq_len:
-            seq_len = max_seq_len
+        seq_len = min(len(tokens), max_seq_len)
         mask_2d = mask_3d[j]
         mask_2d[:seq_len, :seq_len] = 0
......
@@ -16,6 +16,7 @@ from flax.core.frozen_dict import FrozenDict
 from flax.training import train_state
 import transformer_engine.jax as te
+import transformer_engine.jax.flax as te_flax
 IMAGE_H = 28
 IMAGE_W = 28
@@ -32,7 +33,7 @@ class Net(nn.Module):
     @nn.compact
     def __call__(self, x, disable_dropout=False):
         if self.use_te:
-            nn_Dense = te.flax.DenseGeneral
+            nn_Dense = te_flax.DenseGeneral
         else:
             nn_Dense = nn.Dense
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from functools import partial
from typing import Dict
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.flax.transformer import AttentionType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
from transformer_engine.jax.praxis import TransformerEngineBaseLayer, TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType
from utils import assert_allclose
is_fp8_supported, reason = is_fp8_available()
DATA_SHAPE = [(128, 32, 512), (512, 32, 512)]
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
f"{key} not found in test FrozenDict {test_fd}"
        assert isinstance(test_fd[key], type(ref_fd[key])), \
            f"The data type does not match between the ref and test " \
            f"Dict on {key=}"
if isinstance(ref_fd[key], Dict):
compare_dict(ref_fd[key], test_fd[key], rtol, atol)
else:
assert_allclose(ref_fd[key],
test_fd[key],
rtol=rtol,
atol=atol,
err_msg=f"{key=} is not close")
class TestLayer:
@staticmethod
def loss(inner_variables, *inner_inputs, module, mean_out=True):
outs = module.apply(inner_variables, *inner_inputs)
out = outs
if isinstance(outs, tuple):
            # The first element of outs is the real output; the
            # others are auxiliary values.
out = outs[0]
return jnp.mean(out) if mean_out else out
@staticmethod
def loss_and_grads(module, variables, *inputs):
grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
if FP8Helper.is_fp8_enabled():
wgrads = update_fp8_metas(wgrads)
return loss_val, wgrads, dgrad
def input_getter(self, shape, dtype):
raise NotImplementedError
def get_layer_name(self):
raise NotImplementedError
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
raise NotImplementedError
def sync_variables(self, praxis_variables, flax_variables):
synced_praxis_variables = praxis_variables
lyr_name = self.get_layer_name()
synced_praxis_variables['params'][lyr_name]['cld'] = \
flax_variables['params'].unfreeze()
return synced_praxis_variables, flax_variables
def sync_wgrads(self, praxis_wgrads, flax_wgrads):
synced_praxis_grads = praxis_wgrads
lyr_name = self.get_layer_name()
synced_praxis_grads['params'] = \
synced_praxis_grads['params'][lyr_name]['cld']
if FP8Helper.is_fp8_enabled():
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']
return synced_praxis_grads, flax_wgrads.unfreeze()
def forward_backward_runner(self,
data_shape,
dtype,
praxis_p,
flax_cls,
rtol=1e-05,
atol=1e-08):
init_key = jax.random.PRNGKey(seed=1234)
test_inputs = self.input_getter(data_shape, dtype)
praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis:
        # the FP8 meta collection must be listed as mutable so that init() is
        # allowed to create it.
        # TODO (Ming Huang): Come up with a better solution.
mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)
flax_layer = flax_cls()
flax_variables = flax_layer.init(init_key, *test_inputs)
if "params_axes" in flax_variables:
flax_variables, _ = flax_variables.pop("params_axes")
if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax_variables.pop(FP8Helper.FP8_COLLECTION_NAME + "_axes")
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
for _ in range(iter_times):
praxis_loss, praxis_wgrads, praxis_dgrad = \
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
flax_loss, flax_wgrads, flax_dgrad = \
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
if FP8Helper.is_fp8_enabled():
praxis_wgrads.pop('params')
praxis_variables = update_collections(praxis_wgrads, praxis_variables)
flax_wgrads, _ = flax_wgrads.pop('params')
flax_variables = update_collections(flax_wgrads, flax_variables)
praxis_loss, praxis_wgrads, praxis_dgrad = \
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
flax_loss, flax_wgrads, flax_dgrad = \
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)
praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)
class LayerNormAttr:
LN_TYPE = 'layernorm_type'
ZERO_CEN = 'zero_centered_gamma'
ATTRS = [{
LN_TYPE: "layernorm",
ZERO_CEN: False
}, {
LN_TYPE: "layernorm",
ZERO_CEN: True
}, {
LN_TYPE: "rmsnorm",
ZERO_CEN: False
}]
class TestLayerNorm(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'layer_norm'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
layernorm_type = attrs[LayerNormAttr.LN_TYPE]
zero_centered_gamma = attrs[LayerNormAttr.ZERO_CEN]
scale_init = None
bias_init = WeightInit.Constant(0.0)
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNorm,
name='layer_norm',
dtype=dtype,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init,
bias_init=bias_init,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_LayerNorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init,
bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", bias_init),
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class FusedSoftmaxAttr:
SCALE_FACTOR = 'scale_factor'
ST_TYPE = 'softmax_type'
ATTRS = [{
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED
}, {
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED_MASKED
}, {
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED
}]
class TestFusedSoftmax(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return jax.random.normal(data_key, shape, dtype), \
jnp.ones(shape, dtype=jnp.uint8) # Masks
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]
praxis_p = pax_fiddle.Config(FusedSoftmax,
name='fused_softmax',
scale_factor=scale_factor,
softmax_type=softmax_type)
flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)
return praxis_p, flax_cls
def sync_variables(self, praxis_variables, flax_variables):
return praxis_variables, flax_variables
def sync_wgrads(self, praxis_wgrads, flax_wgrads):
return praxis_wgrads, flax_wgrads
@pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)])
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \
(data_shape[-2] != data_shape[-1]):
            pass    # Skip: not supported for non-square data shapes.
else:
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LinearAttr:
FEATURE = 'features'
USE_BIAS = 'use_bias'
ATTRS = [{
FEATURE: 512,
USE_BIAS: False
}, {
FEATURE: 512,
USE_BIAS: True
}, {
FEATURE: 1024,
USE_BIAS: False
}, {
FEATURE: 1024,
USE_BIAS: True
}]
class TestLinear(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'linear'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LinearAttr.FEATURE]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[LinearAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
axis = -1
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(Linear,
name='linear',
dtype=dtype,
out_features=out_features,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(
DenseGeneral,
features=out_features,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis,
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormLinearAttr:
FEATURE = 'features'
USE_BIAS = 'use_bias'
ENABLE_LN = 'enable_layernorm'
LN_TYPE = 'layernorm_type'
ZERO_CEN = 'zero_centered_gamma'
ATTRS = [{
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: False,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}]
class TestLayerNormLinear(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'ln_linear'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LayerNormLinearAttr.FEATURE]
enable_layernorm = attrs[LayerNormLinearAttr.ENABLE_LN]
layernorm_type = attrs[LayerNormLinearAttr.LN_TYPE]
zero_centered_gamma = attrs[LayerNormLinearAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[LayerNormLinearAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
axis = -1
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormLinear,
name='ln_linear',
dtype=dtype,
out_features=out_features,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(
LayerNormDenseGeneral,
features=out_features,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis,
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormMLPAttr:
INTERMEDIATE_DIM = 'intermediate_dim'
USE_BIAS = 'use_bias'
ENABLE_LN = 'enable_layernorm'
LN_TYPE = 'layernorm_type'
ZERO_CEN = 'zero_centered_gamma'
ACTIVATION = 'activations'
ATTRS = [{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',)
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',)
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',)
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear')
}]
class TestLayerNormMLP(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'ln_mlp'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
enable_layernorm = attrs[LayerNormMLPAttr.ENABLE_LN]
layernorm_type = attrs[LayerNormMLPAttr.LN_TYPE]
zero_centered_gamma = attrs[LayerNormMLPAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[LayerNormMLPAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
activations = attrs[LayerNormMLPAttr.ACTIVATION]
axis = -1
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormMLP,
name='ln_mlp',
dtype=dtype,
intermediate_dim=intermediate_dim,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
activations=activations,
intermediate_dropout_rate=0.0,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(
flax_LayerNormMLP,
intermediate_dim=intermediate_dim,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
activations=activations,
intermediate_dropout_rate=0.0,
axis=axis,
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TestRelativePositionBias(TestLayer):
def get_layer_name(self):
return 'relative_position_bias'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_buckets = 32
max_distance = 128
num_attention_heads = 64
rb_stddev = (num_attention_heads * num_buckets)**-0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
praxis_p = pax_fiddle.Config(RelativePositionBiases,
name='relative_position_bias',
dtype=dtype,
num_buckets=num_buckets,
max_distance=max_distance,
num_attention_heads=num_attention_heads,
embedding_init=embedding_init)
flax_cls = partial(flax_RelativePositionBiases,
num_buckets=num_buckets,
max_distance=max_distance,
num_attention_heads=num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init),
dtype=dtype)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', [{}])
def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
init_key = jax.random.PRNGKey(seed=1234)
test_inputs = [(128, 128, True), (128, 128, False)]
for test_input in test_inputs:
praxis_layer = praxis_p.Instantiate()
praxis_variables = praxis_layer.init(init_key, *test_input)
flax_layer = flax_cls()
flax_variables = flax_layer.init(init_key, *test_input)
if "params_axes" in flax_variables:
flax_variables, _ = flax_variables.pop("params_axes")
if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax_variables.pop(FP8Helper.FP8_COLLECTION_NAME + "_axes")
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
            praxis_loss = \
                TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False)
flax_loss = \
TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class MultiHeadAttnAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
ATTN_TYPE = 'attn_type'
ZERO_CEN = 'zero_centered_gamma'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.PADDING
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ATTN_TYPE: AttentionType.PADDING
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.PADDING
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.CAUSAL
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ATTN_TYPE: AttentionType.CAUSAL
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.CAUSAL
}]
class TestMultiHeadAttn(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape,
dtype), jax.random.normal(data_key, shape, dtype))
def get_layer_name(self):
return 'multi_head_attn'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_heads = 16
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm = False
output_layernorm = False
attn_type = attrs[MultiHeadAttnAttr.ATTN_TYPE]
        fuse_qkv = True
transpose_batch_sequence = True
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
praxis_p = pax_fiddle.Config(
MultiHeadAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_heads=num_heads,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
attn_type=attn_type,
fuse_qkv=fuse_qkv,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits)
flax_cls = partial(
flax_MultiHeadAttention,
dtype=dtype,
head_dim=head_dim,
num_heads=num_heads,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
attn_type=attn_type,
fuse_qkv=fuse_qkv,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TransformerLayerAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
ACTIVATION = 'activations'
LYR_TYPE = 'layer_type'
ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False
}]
class TestTransformer(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape,
dtype), jax.random.normal(data_key, shape, dtype))
def get_layer_name(self):
return 'transformerlayer'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
hidden_size = 512
mlp_hidden_size = 2048
num_attention_heads = 8
layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
hidden_dropout = 0.0
attention_dropout = 0.0
mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[TransformerLayerAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
num_attention_heads=num_attention_heads)
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init, relative_embedding.num_attention_heads,
relative_embedding.num_buckets)
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=relative_embedding.num_buckets,
max_distance=relative_embedding.max_distance,
num_attention_heads=relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", rel_embedding_init),
embedding_axes=relative_embedding.embedding_axes,
dtype=relative_embedding.dtype)
praxis_p = pax_fiddle.Config(TransformerLayer,
name='transformer_layer',
params_init=kernel_init,
dtype=dtype,
hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size,
num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
mlp_activations=mlp_activations,
use_bias=use_bias,
bias_init=bias_init,
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_TransformerLayer,
dtype=dtype,
hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size,
num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
mlp_activations=mlp_activations,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", kernel_init),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init(
"bias", bias_init),
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
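Assuming the test file above is collected by pytest in the usual way (the file name and repository path are not shown in this extract), a single parity suite can be run with, e.g., `pytest -k TestLinear test_praxis_layers.py`; the FP8 variants are skipped automatically on GPUs where is_fp8_available() reports no support.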
@@ -36,7 +36,7 @@ TransformerLayerType = deprecate_wrapper(
 __all__ = [
     'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
-    'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'DenseGeneral', 'LayerNorm',
-    'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase', 'MultiHeadAttention',
-    'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
+    'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis', 'DenseGeneral',
+    'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase',
+    'MultiHeadAttention', 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
 ]
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import MultiHeadAttention, RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle
from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType
from ..sharding import MajorShardingType, ShardingType
def _generate_ln_scale_init(scale_init):
if scale_init is not None:
return TransformerEngineBaseLayer.generate_params_init("scale", scale_init)
return scale_init
class TransformerEngineBaseLayer(BaseLayer):
"""TransformerEngineBaseLayer"""
logical_axes_rules: Tuple[Tuple, ...] = None
@staticmethod
def generate_params_init(name: str, initializer: WeightInit):
"""generate_params_init"""
def kernel_init(key, shape, dtype):
wp = WeightHParams(shape=shape, init=initializer, dtype=dtype)
return init_var(wp, key, name)
return kernel_init
def create_layer(self, name, flax_module_cls):
"""create_layer"""
flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter,
module_factory_method=flax_module_cls,
logical_axes_rules=self.logical_axes_rules,
ici_mesh_shape=self.ici_mesh_shape,
dcn_mesh_shape=self.dcn_mesh_shape,
mesh_axis_names=self.mesh_axis_names)
self.create_child(name, flax_module_p.clone())
class LayerNorm(TransformerEngineBaseLayer):
"""LayerNorm"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
super().setup()
ln_cls = partial(flax_LayerNorm,
epsilon=self.epsilon,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.bias_init),
bias_axes=self.bias_axes,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
sharding_type=self.sharding_type)
self.create_layer("layer_norm", ln_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.layer_norm(x)
class FusedSoftmax(TransformerEngineBaseLayer):
"""FusedSoftmax"""
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
super().setup()
fused_softmax_cls = partial(Softmax,
scale_factor=self.scale_factor,
softmax_type=self.softmax_type,
sharding_type=self.sharding_type)
self.create_layer("fused_softmax", fused_softmax_cls)
def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
"""__call__"""
return self.fused_softmax(x, mask, bias)
class Linear(TransformerEngineBaseLayer):
"""Linear"""
out_features: int = 512
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = True
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
super().setup()
dense_general_cls = partial(
DenseGeneral,
features=self.out_features,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
sharding_type=self.sharding_type)
self.create_layer("linear", dense_general_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.linear(x)
class LayerNormLinear(TransformerEngineBaseLayer):
"""LayerNormLinear"""
out_features: int = 512
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = WeightInit.Constant(1.0)
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
return_layernorm_output: bool = True
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
depth_scaling: float = None
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
super().setup()
ln_dense_general_cls = partial(
LayerNormDenseGeneral,
features=self.out_features,
enable_layernorm=self.enable_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.epsilon,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
return_layernorm_output=self.return_layernorm_output,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
depth_scaling=self.depth_scaling,
sharding_type=self.sharding_type)
self.create_layer("ln_linear", ln_dense_general_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.ln_linear(x)
class LayerNormMLP(TransformerEngineBaseLayer):
"""LayerNormMLP"""
intermediate_dim: int = 2048
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = WeightInit.Constant(1.0)
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes_1: Tuple[str, ...] = ()
kernel_axes_2: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = ()
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE
def setup(self) -> None:
"""setup"""
super().setup()
ln_mlp_cls = partial(
flax_LayerNormMLP,
intermediate_dim=self.intermediate_dim,
enable_layernorm=self.enable_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.epsilon,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes_1=self.kernel_axes_1,
kernel_axes_2=self.kernel_axes_2,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes_1=self.bias_axes_1,
bias_axes_2=self.bias_axes_2,
return_layernorm_output=self.return_layernorm_output,
activations=self.activations,
intermediate_dropout_rate=self.intermediate_dropout_rate,
intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
major_sharding_type=self.major_sharding_type)
self.create_layer("ln_mlp", ln_mlp_cls)
def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor:
"""__call__"""
return self.ln_mlp(x, deterministic)
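To make the wrapper pattern above concrete, here is a minimal sketch of configuring and running one of these layers through pax_fiddle, mirroring the unit tests in this commit; it assumes a working JAX device and the optional praxis package, and the shapes are illustrative.

import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from transformer_engine.jax.praxis import Linear

# Configure the Praxis wrapper, then Instantiate() it like any BaseLayer.
linear_p = pax_fiddle.Config(Linear, name='linear', dtype=jnp.float32, out_features=512)
linear = linear_p.Instantiate()

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 512), jnp.float32)

variables = linear.init(key, x)    # Flax-style init of the adapted module
y = linear.apply(variables, x)     # forward pass through the underlying DenseGeneral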
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related to Transformer
"""
from functools import partial
from typing import Optional, Sequence, Tuple
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import AttentionType, TransformerLayerType
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
class RelativePositionBiases(TransformerEngineBaseLayer):
"""RelativePositionBiases"""
num_buckets: int = 32
max_distance: int = 128
num_attention_heads: int = 64
embedding_init: WeightInit = None
embedding_axes: Tuple[str, ...] = ()
@staticmethod
def generate_embedding_init(init, num_attention_heads, num_buckets):
"""generate_embedding_init"""
embedding_init = init
if embedding_init is None:
rb_stddev = (num_attention_heads * num_buckets)**-0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
return embedding_init
def setup(self) -> None:
"""setup"""
super().setup()
embedding_init = RelativePositionBiases.generate_embedding_init(
self.embedding_init, self.num_attention_heads, self.num_buckets)
rpb_cls = partial(flax_RelativePositionBiases,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
num_attention_heads=self.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init),
embedding_axes=self.embedding_axes,
dtype=self.dtype)
self.create_layer("relative_position_bias", rpb_cls)
def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
"""__call__"""
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention"""
head_dim: int = 64
num_heads: int = 16
dropout_rate: float = 0.
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_type: AttentionType = AttentionType.PADDING
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
mha_cls = partial(
flax_MultiHeadAttention,
dtype=self.dtype,
head_dim=self.head_dim,
num_heads=self.num_heads,
dropout_rate=self.dropout_rate,
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_type=self.attn_type,
fuse_qkv=self.fuse_qkv,
transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits)
self.create_layer("multi_head_attn", mha_cls)
def __call__(self,
inputs_q: JTensor,
inputs_kv: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
decode: bool = False,
deterministic: bool = False) -> JTensor:
"""__call__"""
return self.multi_head_attn(inputs_q,
inputs_kv,
mask,
bias,
decode=decode,
deterministic=deterministic)
class TransformerLayer(TransformerEngineBaseLayer):
"""TransformerLayer"""
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
dropout_rng_name: str = 'dropout'
mlp_activations: Sequence[str] = ('relu',)
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
def setup(self) -> None:
"""setup"""
super().setup()
relative_embedding_flax_module = None
if self.enable_relative_embedding and self.relative_embedding is not None:
assert self.relative_embedding.num_attention_heads == \
self.num_attention_heads, \
"TransformerLayer.relative_embedding.num_attention_heads shoule be" \
"the same as TransformerLayer.num_attention_heads."
embedding_init = RelativePositionBiases.generate_embedding_init(
self.relative_embedding.embedding_init, self.relative_embedding.num_attention_heads,
self.relative_embedding.num_buckets)
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=self.relative_embedding.num_buckets,
max_distance=self.relative_embedding.max_distance,
num_attention_heads=self.relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init),
embedding_axes=self.relative_embedding.embedding_axes,
dtype=self.relative_embedding.dtype)
transformerlayer_cls = partial(
flax_TransformerLayer,
dtype=self.dtype,
hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
hidden_dropout=self.hidden_dropout,
hidden_dropout_dims=self.hidden_dropout_dims,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", self.params_init),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", self.params_init),
mlp_activations=self.mlp_activations,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init)
self.create_layer("transformerlayer", transformerlayer_cls)
def __call__(self,
inputs: JTensor,
encoded: JTensor = None,
attention_mask: JTensor = None,
encoder_decoder_mask: JTensor = None,
deterministic: bool = False,
decode: bool = False,
                 max_decode_length: Optional[int] = None) -> JTensor:
"""__call__"""
return self.transformerlayer(inputs, encoded, attention_mask, encoder_decoder_mask,
deterministic, decode, max_decode_length)
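And a matching sketch for the TransformerLayer wrapper, again following the unit tests (shapes and hyperparameters are illustrative; for FP8 runs, note the mutable-collection workaround used at init time in the tests above):

import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType

layer_p = pax_fiddle.Config(TransformerLayer,
                            name='transformer_layer',
                            dtype=jnp.bfloat16,
                            hidden_size=512,
                            mlp_hidden_size=2048,
                            num_attention_heads=8,
                            hidden_dropout=0.0,
                            attention_dropout=0.0,
                            layer_type=TransformerLayerType.ENCODER,
                            enable_relative_embedding=False,
                            transpose_batch_sequence=False)
layer = layer_p.Instantiate()

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 512), jnp.bfloat16)    # (batch, seq, hidden)

# The unit tests pass the same tensor for inputs and encoded;
# for an ENCODER layer the encoded argument is unused.
variables = layer.init(key, x, x)
y = layer.apply(variables, x, x)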