Unverified commit b7acb6e1, authored by Trevor Morris and committed by GitHub

Add TensorFlow module and extensions (#85)



* Add tensorflow build

Improve build instructions

Fix pybind enum usage

Fix Python_EXECUTABLE cmake var

Move scale_inv calculations to FW
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Apply clang-format
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Format python files
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add TF build CI
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Lint checks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Another round of lint checks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix TF image tag
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Use the existing recipe file
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add license claim blocks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix a bug about bias dtype conversion
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add mnist example and cleanup old examples
Signed-off-by: kaixih <kaixih@nvidia.com>

* Autopep8 the tests
Signed-off-by: kaixih <kaixih@nvidia.com>

* Autopep8 the examples
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add example in Readme
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add unit tests and linting for TensorFlow
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add causal mask for non-fused case
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix the mismatched TF vs TE masks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Addressing CI tests
Signed-off-by: kaixih <kaixih@nvidia.com>

* Run lint test
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add missing import
Signed-off-by: kaixih <kaixih@nvidia.com>

* Skip fp8 tests for pre-Hopper GPUs
Signed-off-by: kaixih <kaixih@nvidia.com>

* Remove non-pytest tests
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: kaixih <kaixih@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 0963b288
@@ -31,3 +31,27 @@ jobs:
run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
- name: 'Sanity check'
run: python tests/pytorch/test_sanity_import.py
TensorFlow:
name: 'TensorFlow build'
runs-on: ubuntu-latest
container:
image: nvcr.io/nvidia/tensorflow:23.02-tf2-py3
options: --user root
steps:
- name: 'Checkout'
uses: actions/checkout@v3
- name: 'Build'
run: |
apt-get update && apt-get install -y ninja-build pybind11-dev
mkdir -p wheelhouse && \
NVTE_FRAMEWORK=tensorflow pip wheel -w wheelhouse . -v
- name: 'Upload wheel'
uses: actions/upload-artifact@v3
with:
name: te_wheel_tf
path: wheelhouse/transformer_engine*.whl
retention-days: 7
- name: 'Install'
run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
- name: 'Sanity check'
run: python tests/tensorflow/test_sanity_import.py
@@ -108,6 +108,37 @@ JAX
# Update FP8 metas
other_variables = te.update_fp8_metas(other_grads)
TensorFlow
^^^^^^^^^^
.. code-block:: python
import tensorflow as tf
import transformer_engine.tensorflow as te
from transformer_engine.common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te.Dense(out_features, use_bias=True)
inp = tf.random.normal((hidden_size, in_features))
optimizer = tf.keras.optimizers.Adam(0.001)
# Create FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
with tf.GradientTape(persistent=True) as tape:
# Enables autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = model(inp)
loss = tf.reduce_sum(out)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
Highlights
----------
@@ -131,6 +162,12 @@ Transformer Engine comes preinstalled in the pyTorch container on
From source
^^^^^^^^^^^
First, install the prerequisites.
.. code-block:: bash
apt-get install ninja-build pybind11-dev
Clone the repository and inside it type:
.. code-block:: bash
@@ -139,6 +176,19 @@ Clone the repository and inside it type:
NVTE_FRAMEWORK=pytorch pip install . # Building with PyTorch only.
NVTE_FRAMEWORK=jax pip install . # Building with JAX only.
You can also specify which framework bindings to build. The default is PyTorch only.
.. code-block:: bash
# Build with TensorFlow bindings
NVTE_FRAMEWORK=tensorflow pip install .
# Build with JAX bindings
NVTE_FRAMEWORK=jax pip install .
# Build with all bindings (PyTorch, TensorFlow, JAX)
NVTE_FRAMEWORK=all pip install .
User Guide
----------
......
# Basic MNIST Example with optional FP8
```bash
python mnist.py
python mnist.py --use-te # Linear layers from TransformerEngine
python mnist.py --use-fp8 # FP8 + TransformerEngine for Linear layers
```
# Benchmark A Basic Transformer Layer
```bash
python transformer_layer.py
```
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import tensorflow as tf
import tensorflow_datasets as tfds
import transformer_engine.tensorflow as te
class MNIST(tf.keras.Model):
def __init__(self, use_te=False):
super().__init__()
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
if use_te:
self.dense1 = te.Dense(128, kernel_initializer='glorot_uniform',
bias_initializer='zeros')
else:
self.dense1 = tf.keras.layers.Dense(128, activation=None)
self.relu = tf.keras.layers.ReLU()
self.dense2 = tf.keras.layers.Dense(10)
def call(self, x):
x = self.flatten(x)
x = self.dense1(x)
x = self.relu(x)
y = self.dense2(x)
return y
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
def train_step(inputs, model, optimizer, use_fp8, fp8_recipe=None):
x, labels = inputs
with tf.GradientTape(persistent=True) as tape:
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = model(x, training=True)
loss = loss_func(labels, y)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
val_loss = tf.keras.metrics.Mean(name='val_loss', dtype=tf.float32)
val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
def valid_step(inputs, model):
x, labels = inputs
predictions = model(x, training=False)
loss = loss_func(labels, predictions)
val_loss.update_state(loss)
val_accuracy.update_state(labels, predictions)
def main():
# Training settings
parser = argparse.ArgumentParser(description="TensorFlow MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=128,
metavar="N",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
metavar="N",
help="input batch size for testing (default: 128)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--lr",
type=float,
default=0.001,
metavar="LR",
help="learning rate (default: 0.001)",
)
parser.add_argument(
"--seed", type=int, default=12, metavar="S",
help="random seed (default: 12)"
)
parser.add_argument(
"--use-fp8", action="store_true", default=False,
help="Use FP8 for inference and training without recalibration"
)
parser.add_argument(
"--use-te", action="store_true", default=False,
help="Use Transformer Engine"
)
args = parser.parse_args()
batch_size = args.batch_size
test_batch_size = args.test_batch_size
num_epoch = args.epochs
tf.random.set_seed(args.seed)
tf.keras.utils.set_random_seed(args.seed)
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
nstep_per_epoch = len(ds_train) // batch_size
nstep_per_valid = len(ds_test) // test_batch_size
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
model = MNIST(use_te=(args.use_te or args.use_fp8))
optimizer = tf.keras.optimizers.Adam(args.lr)
fp8_recipe = te.DelayedScaling(
margin=0, interval=1, fp8_format=te.Format.HYBRID,
amax_compute_algo='max', amax_history_len=16)
for i in range(num_epoch):
ds_train_iter = iter(ds_train)
for _ in range(nstep_per_epoch):
inputs = next(ds_train_iter)
_ = train_step(inputs, model, optimizer, use_fp8=args.use_fp8,
fp8_recipe=fp8_recipe)
val_loss.reset_states()
val_accuracy.reset_states()
ds_test_iter = iter(ds_test)
for _ in range(nstep_per_valid):
inputs = next(ds_test_iter)
valid_step(inputs, model)
print("epoch-{} loss: {} - accuracy: {}".format(
i, val_loss.result(), val_accuracy.result()))
if __name__ == "__main__":
main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.tensorflow import Format, DelayedScaling
import argparse
import tensorflow as tf
import time
import transformer_engine.tensorflow as te
from keras import layers
from keras import Model
from typing import Optional
parser = argparse.ArgumentParser(description="Benchmark TransformerLayer.")
parser.add_argument(
'-t', '--type', type=int, default=0,
help="""Pick TE implementation (0:all|1:TF-fp16|2:TE-fp16|3:TE-fp8)""")
args, _ = parser.parse_known_args()
tl_type = args.type
tf.keras.mixed_precision.set_global_policy('mixed_float16')
dropout_rate = 0.0
class DotProductAttention(tf.keras.Model):
"""Attention operation in Transformer layer
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
):
super().__init__()
self.projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = float(kv_channels)
self.norm_factor = tf.math.sqrt(self.hidden_size_per_attention_head)
self.dropout = layers.Dropout(attention_dropout)
if self.dropout.dtype_policy.name == 'mixed_float16':
self.norm_factor = tf.cast(self.norm_factor, dtype=tf.float16)
def masked_softmax(
self,
inp: tf.Tensor,
mask: Optional[tf.Tensor]
) -> tf.Tensor:
if mask is not None:
inp = tf.where(mask, -10000.0, inp)
return tf.nn.softmax(inp, axis=-1)
def call(
self,
query: tf.Tensor,
key: tf.Tensor,
value: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
) -> tf.Tensor:
b = query.shape[1]
np = query.shape[2]
sq = query.shape[0]
sk = key.shape[0]
hn = value.shape[3]
# [sq, b, np, hn] -> [sq, b * np, hn]
query = tf.reshape(query, shape=(sq, b * np, hn))
# [sk, b, np, hn] -> [sk, b * np, hn]
key = tf.reshape(key, shape=(sk, b * np, hn))
bmm1 = tf.matmul(tf.transpose(query, perm=(1, 0, 2)),
tf.transpose(key, perm=(1, 2, 0))) / self.norm_factor
# change view to [b, np, sq, sk]
attention_scores = tf.reshape(bmm1, shape=(b, np, sq, sk))
attention_probs = self.masked_softmax(attention_scores, attention_mask)
attention_probs = self.dropout(attention_probs)
# change view [sk, b * np, hn]
value = tf.reshape(value, shape=(sk, b * np, hn))
# change view [b * np, sq, sk]
attention_probs = tf.reshape(attention_probs, shape=(b * np, sq, sk))
# matmul: [b * np, sq, hn]
context = tf.matmul(attention_probs,
tf.transpose(value, perm=(1, 0, 2)))
# change view [b, np, sq, hn]
context = tf.reshape(context, shape=(b, np, sq, hn))
# [b, np, sq, hn] --> [sq, b, np, hn]
context = tf.transpose(context, perm=(2, 0, 1, 3))
# [sq, b, np, hn] --> [sq, b, hp]
context = tf.reshape(context, shape=(sq, b, self.projection_size))
return context
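The reshape/transpose bookkeeping in `DotProductAttention.call` above is easy to get wrong. As a sanity sketch, the same batched-matmul path can be checked against a direct einsum in plain NumPy; the small shapes here are hypothetical and scaling, mask, and dropout are omitted:

```python
import numpy as np

# Hypothetical small dimensions: seq lengths sq/sk, batch b, heads np_, head dim hn.
sq, sk, b, np_, hn = 3, 3, 2, 4, 5
q = np.random.rand(sq, b, np_, hn).astype(np.float32)
k = np.random.rand(sk, b, np_, hn).astype(np.float32)
v = np.random.rand(sk, b, np_, hn).astype(np.float32)

# Batched-matmul path, mirroring the reshapes/transposes in the layer.
q2 = q.reshape(sq, b * np_, hn).transpose(1, 0, 2)               # [b*np, sq, hn]
k2 = k.reshape(sk, b * np_, hn).transpose(1, 2, 0)               # [b*np, hn, sk]
scores = (q2 @ k2) / np.sqrt(hn)                                 # [b*np, sq, sk]
probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)   # softmax
v2 = v.reshape(sk, b * np_, hn).transpose(1, 0, 2)               # [b*np, sk, hn]
ctx = (probs @ v2).reshape(b, np_, sq, hn).transpose(2, 0, 1, 3)  # [sq, b, np, hn]

# Direct einsum reference over the original 4-D layout.
scores_ref = np.einsum('sbnh,tbnh->bnst', q, k) / np.sqrt(hn)
probs_ref = np.exp(scores_ref) / np.exp(scores_ref).sum(-1, keepdims=True)
ctx_ref = np.einsum('bnst,tbnh->sbnh', probs_ref, v)
assert np.allclose(ctx, ctx_ref, atol=1e-5)
```

The final `[sq, b, np, hn]` result corresponds to the layer's output just before its last reshape to `[sq, b, projection_size]`.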
class BasicMLP(tf.keras.Model):
"""Feed-forward network in Transformer layer
"""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
):
super().__init__()
self.linear1 = layers.Dense(ffn_hidden_size, use_bias=True)
self.linear2 = layers.Dense(hidden_size, use_bias=True)
def call(
self,
x: tf.Tensor
) -> tf.Tensor:
x = self.linear1(x)
x = tf.nn.gelu(x, approximate=True)
x = self.linear2(x)
return x
class BasicTransformer(tf.keras.Model):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: int = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = layers.LayerNormalization(epsilon=layernorm_eps)
self.qkv_projection = layers.Dense(3 * hidden_size, use_bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = layers.Dense(hidden_size, use_bias=True)
self.dropout = layers.Dropout(hidden_dropout)
self.ln2 = layers.LayerNormalization(epsilon=layernorm_eps)
self.mlp = BasicMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
)
def call(
self,
x: tf.Tensor,
attention_mask: tf.Tensor,
) -> tf.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv_shape = qkv.shape
qkv = tf.reshape(qkv,
shape=(qkv_shape[0], qkv_shape[1],
self.num_attention_heads, 3 * self.kv_channels))
q, k, v = tf.split(qkv, 3, axis=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout(x)
x = res + x
res = x
x = self.ln2(x)
x = self.mlp(x)
return x + res
class FusedTETransformer(tf.keras.Model):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: int = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln_qkv = te.LayerNormDense(3 * hidden_size, epsilon=layernorm_eps,
use_bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = te.Dense(hidden_size, use_bias=True)
self.dropout = layers.Dropout(hidden_dropout)
self.ln_mlp = te.LayerNormMLP(ffn_hidden_size, hidden_size,
epsilon=layernorm_eps, use_bias=True,
return_layernorm_output=False)
def call(
self,
x: tf.Tensor,
attention_mask: tf.Tensor,
) -> tf.Tensor:
res = x
qkv = self.ln_qkv(x)
# Split qkv into query, key and value
qkv_shape = qkv.shape
qkv = tf.reshape(qkv,
shape=(qkv_shape[0], qkv_shape[1],
self.num_attention_heads, 3 * self.kv_channels))
q, k, v = tf.split(qkv, 3, axis=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout(x)
x = res + x
res = x
x = self.ln_mlp(x)
return x + res
# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = tf.float32
def speedometer(
model: tf.keras.Model,
input: tf.Tensor,
forward_kwargs: dict = {},
fp8_autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
) -> None:
"""Measure average run time for a TF model
Performs forward and backward passes.
"""
if fp8_autocast_kwargs is None:
fp8_autocast_kwargs = {"enabled": False}
p = tf.constant(0.) # Create small tensor to force GPU resync
# Warmup runs
for _ in range(warmup_iters):
with tf.GradientTape(persistent=True) as tape:
tape.watch(input)
with te.fp8_autocast(**fp8_autocast_kwargs):
output = model(input, **forward_kwargs)
loss = tf.reduce_sum(output)
dx, dvars = tape.gradient(loss, [input, model.variables])
(p + 1.).numpy() # Sync the GPU
# Timing runs
start = time.time()
for _ in range(timing_iters):
with tf.GradientTape(persistent=True) as tape:
tape.watch(input)
with te.fp8_autocast(**fp8_autocast_kwargs):
output = model(input, **forward_kwargs)
loss = tf.reduce_sum(output)
dx, dvars = tape.gradient(loss, [input, model.variables])
(p + 1.).numpy() # Sync the GPU
end = time.time()
elapsed_time = (end - start) / timing_iters * 1000
print(f"Mean time: {elapsed_time} ms")
tf.random.set_seed(12)
tf.keras.utils.set_random_seed(1)
# Synthetic data
x = tf.random.normal(shape=(sequence_length, batch_size, hidden_size),
dtype=dtype)
basic_transformer = BasicTransformer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
attention_dropout=dropout_rate,
hidden_dropout=dropout_rate,
)
y = basic_transformer(x, attention_mask=None)
if tl_type in (0, 1):
print("Running with native TF:")
speedometer(
basic_transformer,
x,
forward_kwargs={"attention_mask": None, "training": True},
)
te_transformer = FusedTETransformer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
attention_dropout=dropout_rate,
hidden_dropout=dropout_rate,
)
fp8_recipe = DelayedScaling(margin=0, interval=1, fp8_format=Format.HYBRID,
amax_compute_algo='max', amax_history_len=16)
# Run once to build the variables.
te_transformer(x, attention_mask=None)
# Sync the variables with the reference.
for v0, v1 in zip(basic_transformer.variables, te_transformer.variables):
v1.assign(v0)
tf.debugging.assert_near(v1, v0)
y_te = te_transformer(x, attention_mask=None)
if tl_type in (0, 2):
print("Running with TE:")
speedometer(
te_transformer,
x,
forward_kwargs={"attention_mask": None, "training": True},
fp8_autocast_kwargs={"enabled": False, "fp8_recipe": None},
)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
y_te = te_transformer(x, attention_mask=None)
if tl_type in (0, 3):
print("Running with TE and fp8:")
speedometer(
te_transformer,
x,
forward_kwargs={"attention_mask": None, "training": True},
fp8_autocast_kwargs={"enabled": True, "fp8_recipe": fp8_recipe},
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Stop searching for additional config files.
set noparent
# Limit line length.
linelength=100
# Ignore the following errors.
filter=-build/include_subdir
filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
[MASTER]
extension-pkg-whitelist=transformer_engine_tensorflow
disable=too-many-locals,
invalid-name,
too-many-arguments,
abstract-method,
arguments-differ,
too-many-instance-attributes,
unsubscriptable-object,
import-outside-toplevel,
too-many-statements,
import-error,
too-many-lines,
use-maxsplit-arg,
protected-access,
pointless-string-statement,
cyclic-import,
duplicate-code,
no-member,
attribute-defined-outside-init,
global-statement,
too-many-branches,
global-variable-not-assigned,
redefined-argument-from-local
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==2.13.5
if [ -z "${PYTHON_ONLY}" ]
then
cp $TE_PATH/qa/L0_tensorflow_lint/CPPLINT.cfg $TE_PATH
cd $TE_PATH
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include transformer_engine/common
cpplint --recursive transformer_engine/tensorflow
fi
if [ -z "${CPP_ONLY}" ]
then
cp $TE_PATH/qa/L0_tensorflow_lint/pylintrc $TE_PATH
cd $TE_PATH
echo "Checking Python files"
pylint --recursive=y transformer_engine/common transformer_engine/tensorflow
fi
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/tensorflow
@@ -95,6 +95,7 @@ supported_frameworks = {
"all": all_sources,
"pytorch": pytorch_sources,
"jax": None, # JAX uses transformer_engine/CMakeLists.txt
"tensorflow": None, # TensorFlow uses transformer_engine/CMakeLists.txt
}
framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")
@@ -166,6 +167,17 @@ class JaxBuilder(FrameworkBuilderBase):
def run(self, extensions):
print("Building jax extensions!")
class TensorFlowBuilder(FrameworkBuilderBase):
def cmake_flags(self):
return ["-DENABLE_TENSORFLOW=ON"]
def run(self, extensions):
print("Building TensorFlow extensions!")
@staticmethod
def install_requires():
return ["pydantic",]
ext_modules = []
dlfw_builder_funcs = []
@@ -196,6 +208,9 @@ if framework in ("all", "pytorch"):
if framework in ("all", "jax"):
dlfw_builder_funcs.append(JaxBuilder)
if framework in ("all", "tensorflow"):
dlfw_builder_funcs.append(TensorFlowBuilder)
dlfw_install_requires = []
for builder in dlfw_builder_funcs:
dlfw_install_requires = dlfw_install_requires + builder.install_requires()
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the cpp extensions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import transformer_engine # pylint: disable=unused-import
import transformer_engine_tensorflow as tex
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from transformer_engine.tensorflow import TE_DType
from transformer_engine.tensorflow import get_stream_id
class ExtensionsTest(test.TestCase):
@test_util.run_gpu_only
def testCastFp8(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
input_shape = (16, 32)
x = tf.random.uniform(input_shape)
scale, amax, scale_inv = tf.ones([]), tf.zeros([]), tf.ones([])
offset = 0
for fp8_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
stream_id = get_stream_id()
x_fp8 = tex.cast_to_fp8(
x, scale, fp8_dtype, amax, scale_inv, offset, stream_id)
y = tex.cast_from_fp8(
x_fp8, scale_inv, fp8_dtype, TE_DType[x.dtype],
offset, stream_id)
self.assertAllClose(y, x, rtol=0.1, atol=0.01)
@test_util.run_gpu_only
def testTransposeFp8(self):
stream_id = get_stream_id()
x = tf.constant(np.random.uniform(-128, 127, (16, 32)), dtype=tf.int8)
y = tex.fp8_transpose(x, tex.DType.kFloat8E4M3, stream_id)
y_ref = tf.transpose(x, [1, 0])
self.assertAllEqual(y, y_ref)
@test_util.run_gpu_only
def testMatmulFp8(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
stream_id = get_stream_id()
fp8_dtype = tex.DType.kFloat8E4M3
out_dtype = tex.DType.kFloat32
a = tf.random.uniform([32, 16])
a_scale, a_amax, a_scale_inv = tf.ones([]), tf.zeros([]), tf.ones([])
a_offset = 0
a_casted = tex.cast_to_fp8(a, a_scale, fp8_dtype, a_amax, a_scale_inv,
a_offset, stream_id)
b = tf.random.uniform([16, 16])
b_scale, b_amax, b_scale_inv = tf.ones([]), tf.zeros([]), tf.ones([])
b_offset = 0
b_casted = tex.cast_to_fp8(b, b_scale, fp8_dtype, b_amax, b_scale_inv,
b_offset, stream_id)
use_bias = False
bias = tf.zeros(())
workspace = tf.zeros([33_554_432], dtype=tf.int8)
# CublasLt inside tex.te_gemm assumes inputs are column major.
# Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the
# transpose of X. In fact, if we view X^T as the column-major layout of
# X, we don't need any explicit transpose.
# Note: for fp8 matmul, the first matrix has to be in transposed format.
d = tex.te_gemm(b_casted, b_scale_inv, fp8_dtype, b_offset, a_casted,
a_scale_inv, fp8_dtype, a_offset, workspace, use_bias,
bias, False, None, True, False, False, False, False,
out_dtype, stream_id)
# We assume b is in transposed format (see above). So we transpose it
# back to apply the ordinary row-major matmul.
bt = tf.transpose(b)
d_ref = tf.matmul(a, bt)
self.assertAllClose(d, d_ref, rtol=0.1, atol=0.01)
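The column-major argument in the comment above rests on the identity (A@B)^T = B^T @ A^T: a row-major buffer of X, read as column-major, is exactly X^T, so swapping the operands of a column-major GEMM yields the row-major product with no explicit transpose kernel. A minimal NumPy sketch of that identity, with hypothetical shapes matching the test:

```python
import numpy as np

a = np.random.rand(32, 16).astype(np.float32)
b = np.random.rand(16, 16).astype(np.float32)

# A column-major GEMM computing B^T @ A^T produces (A @ B)^T, which,
# read back as a row-major buffer, is A @ B. No transpose kernels needed.
assert np.allclose(b.T @ a.T, (a @ b).T)
```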
@test_util.run_gpu_only
def testLayerNormFwdFp8(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
stream_id = get_stream_id()
fp8_dtype = tex.DType.kFloat8E4M3
N, H = (16, 32)
eps = 1e-3
x = tf.random.uniform((N, H))
gamma = tf.random.uniform((H,))
beta = tf.random.uniform((H,))
offset = 0
scale, amax, scale_inv = tf.ones([]), tf.zeros([]), tf.ones([])
y_ref, mu_ref, rsigma_ref = tex.layernorm_fwd(
x, gamma, beta, eps, stream_id)
y_fp8, mu, rsigma = tex.layernorm_fwd_fp8(
x, gamma, beta, eps, scale, fp8_dtype, amax, scale_inv, offset,
stream_id)
y = tex.cast_from_fp8(y_fp8, scale_inv, fp8_dtype, TE_DType[x.dtype],
offset, stream_id)
self.assertAllClose(y, y_ref, rtol=0.1, atol=0.01)
self.assertAllClose(mu, mu_ref)
self.assertAllClose(rsigma, rsigma_ref)
@test_util.run_gpu_only
def testGeluForwardFp8(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
stream_id = get_stream_id()
fp8_dtype = tex.DType.kFloat8E4M3
M, N = (16, 32)
x = tf.random.uniform((M, N))
offset = 0
scale, amax, scale_inv = tf.ones([]), tf.zeros([]), tf.ones([])
y_ref = tf.nn.gelu(x, approximate=True)
y_fp8 = tex.te_gelu(x, scale, fp8_dtype, amax,
scale_inv, offset, stream_id)
y = tex.cast_from_fp8(y_fp8, scale_inv, fp8_dtype, TE_DType[x.dtype],
offset, stream_id)
self.assertAllClose(y, y_ref, rtol=0.1, atol=0.01)
@test_util.run_gpu_only
def testGeluForward(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
stream_id = get_stream_id()
M, N = (16, 32)
x = tf.random.uniform((M, N))
y_ref = tf.nn.gelu(x, approximate=True)
y = tex.te_gelu(x, None, TE_DType[x.dtype], None, None, 0, stream_id)
self.assertAllClose(y, y_ref, rtol=0.00001, atol=0.00001)
@test_util.run_gpu_only
def testGeluBackwardFp8(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
stream_id = get_stream_id()
fp8_dtype = tex.DType.kFloat8E5M2
M, K, N = (16, 32, 32)
x = tf.random.uniform((M, K))
bias = tf.random.uniform((K, ))
dy = tf.random.uniform((M, K))
offset = 0
scale, amax, scale_inv = tf.ones([]), tf.zeros([]), tf.ones([])
with tf.GradientTape(persistent=True) as tape:
tape.watch([x, bias])
x_gelu = tf.nn.bias_add(x, bias)
y = tf.nn.gelu(x_gelu, approximate=True)
loss = y * dy
dgelu_ref, dbias_ref = tape.gradient(loss, [x_gelu, bias])
dbias, dgelu_c, dgelu_t = tex.fp8_fused_cast_transpose_bgrad_dgelu(
dy, x_gelu, scale, fp8_dtype, amax, scale_inv, offset, stream_id)
dgelu = tex.cast_from_fp8(
dgelu_c, scale_inv, fp8_dtype, TE_DType[x.dtype], offset, stream_id)
self.assertAllClose(dgelu, dgelu_ref, rtol=0.1, atol=0.01)
self.assertAllClose(dbias, dbias_ref)
self.assertAllEqual(dgelu_c, tf.transpose(dgelu_t, [1, 0]))
@test_util.run_gpu_only
def testScaledUpperTriangMaskedSoftmaxFwd(self):
stream_id = get_stream_id()
B, F = (16, 32)
scale = 0.8
x = tf.random.uniform((B, F, F), dtype=tf.half)
mask_operator = tf.linalg.LinearOperatorLowerTriangular(
tf.ones((F, F), dtype=tf.bool))
mask = mask_operator.to_dense()
mask_output = tf.where(mask, scale * x, -10000.0)
y_ref = tf.nn.softmax(mask_output, axis=-1)
y = tex.scaled_upper_triang_masked_softmax_forward(x, scale, stream_id)
self.assertAllClose(y, y_ref, rtol=0.001, atol=0.001)
@test_util.run_gpu_only
def testScaledUpperTriangMaskedSoftmaxBwd(self):
stream_id = get_stream_id()
B, F = (16, 32)
scale = 0.8
x = tf.random.uniform((B, F, F), dtype=tf.half)
dy = tf.random.uniform((B, F, F), dtype=tf.half)
mask_operator = tf.linalg.LinearOperatorLowerTriangular(
tf.ones((F, F), dtype=tf.bool))
mask = mask_operator.to_dense()
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
mask_output = tf.where(mask, scale * x, -10000.0)
y = tf.nn.softmax(mask_output, axis=-1)
y = tf.cast(y, dtype=tf.half)
loss = y * dy
dx_ref = tape.gradient(loss, x)
dx = tex.scaled_upper_triang_masked_softmax_backward(
dy, y, scale, stream_id)
self.assertAllClose(dx, dx_ref, rtol=0.001, atol=0.001)
@test_util.run_gpu_only
def testScaledMaskedSoftmaxFwd(self):
stream_id = get_stream_id()
B, N, F = (16, 4, 32)
scale = 0.8
x = tf.random.uniform((B, N, F, F), dtype=tf.half)
# In NVTE, if the mask is true, the corresponding value is zeroed out,
# whereas TF does the opposite. In addition, NVTE requires the mask to
# have the same number of dims as the input.
mask = tf.reshape(x[0, 0] > 0.3, shape=(1, 1, F, F))
flipped_mask = x[0, 0] <= 0.3
y_ref = tf.keras.layers.Softmax(axis=-1)(scale * x, flipped_mask)
y = tex.scaled_masked_softmax_forward(x, mask, scale, stream_id)
self.assertAllClose(y, y_ref, rtol=0.001, atol=0.001)
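The mask-convention comment above (NVTE: True means "mask out"; Keras `Softmax`: True means "keep") can be illustrated in plain NumPy, assuming the usual large-negative-fill softmax masking:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

x = np.random.rand(4, 4).astype(np.float32)
nvte_mask = x > 0.3      # True => position is masked OUT
keras_mask = ~nvte_mask  # True => position is KEPT (the flipped mask)

# Filling masked positions with a large negative value gives the same
# result under either convention once the mask is flipped.
y_nvte = softmax(np.where(nvte_mask, -10000.0, x))
y_keras = softmax(np.where(keras_mask, x, -10000.0))
assert np.allclose(y_nvte, y_keras)
```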
@test_util.run_gpu_only
def testScaledMaskedSoftmaxBwd(self):
stream_id = get_stream_id()
B, N, F = (16, 4, 32)
scale = 0.8
x = tf.random.uniform((B, N, F, F), dtype=tf.half)
dy = tf.random.uniform((B, N, F, F), dtype=tf.half)
mask = tf.reshape(x[0, 0] > 0.3, shape=(1, 1, F, F))
flipped_mask = x[0, 0] <= 0.3
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y = tf.keras.layers.Softmax(axis=-1)(scale * x, flipped_mask)
y = tf.cast(y, dtype=tf.half)
loss = y * dy
dx_ref = tape.gradient(loss, x)
dx = tex.scaled_masked_softmax_backward(dy, y, scale, stream_id)
self.assertAllClose(dx, dx_ref, rtol=0.001, atol=0.001)
@test_util.run_gpu_only
def testScaledSoftmaxFwd(self):
stream_id = get_stream_id()
B, N, F = (16, 4, 32)
scale = 0.8
x = tf.random.uniform((B, N, F, F), dtype=tf.half)
y_ref = tf.keras.layers.Softmax(axis=-1)(scale * x)
y = tex.scaled_softmax_forward(x, scale, stream_id)
self.assertAllClose(y, y_ref, rtol=0.001, atol=0.001)
@test_util.run_gpu_only
def testScaledSoftmaxBwd(self):
stream_id = get_stream_id()
B, N, F = (16, 4, 32)
scale = 0.8
x = tf.random.uniform((B, N, F, F), dtype=tf.half)
dy = tf.random.uniform((B, N, F, F), dtype=tf.half)
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y = tf.keras.layers.Softmax(axis=-1)(scale * x)
y = tf.cast(y, tf.half)
loss = y * dy
dx_ref = tape.gradient(loss, x)
dx = tex.scaled_softmax_backward(dy, y, scale, stream_id)
self.assertAllClose(dx, dx_ref, rtol=0.001, atol=0.001)
if __name__ == '__main__':
test.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the fp8 layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import transformer_engine.tensorflow as te
from itertools import product
from tensorflow.keras import initializers, layers
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from transformer_engine.tensorflow import (
Dense,
DelayedScaling,
Format,
LayerNorm,
LayerNormDense,
LayerNormMLP,
)
def get_fp8_recipe(override_wgrad=False):
fp8_recipe = DelayedScaling(
margin=0, interval=1, fp8_format=Format.HYBRID,
amax_compute_algo='max', amax_history_len=3,
override_linear_precision=(False, False, override_wgrad))
return fp8_recipe
def compute_scale(amax, scale, fp8_max, margin):
"""Default function to convert amax to scaling factor."""
exp = tf.math.floor(tf.experimental.numpy.log2(fp8_max / amax)) - margin
sf = tf.math.round(tf.math.pow(2., tf.math.abs(exp)))
sf = tf.where(amax > 0.0, sf, scale)
sf = tf.where(tf.math.is_finite(amax), sf, scale)
sf = tf.where(exp < 0, 1.0 / sf, sf)
return sf
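`compute_scale` above picks the largest power-of-two scale that keeps `amax * scale` within the FP8 representable range, backing off by `margin` bits and keeping the previous scale when `amax` is zero or non-finite. A scalar, stdlib-only mirror of that rule (the helper name and the concrete amax values are illustrative, not part of the library):

```python
import math

def compute_scale_scalar(amax, prev_scale, fp8_max, margin=0):
    """Scalar mirror of compute_scale: a power-of-two scale such that
    amax * scale stays at or below fp8_max, backed off by `margin` bits."""
    if not math.isfinite(amax) or amax <= 0.0:
        return prev_scale  # keep the previous scale for zero/inf/NaN amax
    exp = math.floor(math.log2(fp8_max / amax)) - margin
    sf = round(2.0 ** abs(exp))
    return 1.0 / sf if exp < 0 else float(sf)

# E4M3 max magnitude is 448: an amax of 3.5 allows a scale of 2**7 = 128
# (3.5 * 128 = 448), while an amax of 1000 forces a downscale of 0.25.
assert compute_scale_scalar(3.5, 1.0, 448.0) == 128.0
assert compute_scale_scalar(1000.0, 1.0, 448.0) == 0.25
```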
def update_scale(amax_h, scale, fp8_meta, is_fwd):
key = "fp8_max_fwd" if is_fwd else "fp8_max_bwd"
amax = tf.reduce_max(amax_h, axis=0)
fp8_max = fp8_meta[key]
margin = fp8_meta["recipe"].margin
scale = compute_scale(amax, scale, fp8_max, margin)
scale_inv = 1. / scale
return scale, scale_inv
def roll_and_update(amax_h, update):
amax_h = tf.roll(amax_h, shift=-1, axis=0)
amax_h = tf.tensor_scatter_nd_update(amax_h, [[0]], [update])
return amax_h
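The amax-history bookkeeping above can be sketched with plain lists; `roll_and_update_py` is an illustrative mirror, not part of the module:

```python
def roll_and_update_py(amax_h, update):
    # List-based mirror of roll_and_update above: tf.roll(shift=-1)
    # cycles the rows up by one, then the scatter overwrites index 0
    # with the newest amax row.
    rolled = amax_h[1:] + amax_h[:1]
    rolled[0] = update
    return rolled
```

For example, `roll_and_update_py([10, 20, 30], 99)` returns `[99, 30, 10]`, matching the TF roll-then-scatter semantics.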
# Recompute the reference results of the layer norm backward pass.
def get_adjusted_layernorm_dx(x, ln_dy, init):
assert x.shape == ln_dy.shape
ln_layer = layers.LayerNormalization(
gamma_initializer=init,
beta_initializer=init,
)
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y = ln_layer(x)
loss = y * ln_dy
ln_dx, (ln_dgamma, ln_dbeta) = tape.gradient(loss, [x, ln_layer.variables])
return ln_dx, ln_dgamma, ln_dbeta
class LayersTest(test.TestCase):
@test_util.run_gpu_only
def testDenseFwd(self):
B, M, K, N = 4, 8, 16, 32
init = initializers.RandomUniform(minval=0., maxval=1.)
dense_kwargs = {
"units": N,
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
dense_ref = layers.Dense(**dense_kwargs)
dense = Dense(**dense_kwargs)
x = tf.random.uniform((B, M, K))
fp8_recipe = get_fp8_recipe()
for use_fp8 in [False, True]:
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
y_ref = dense_ref(x)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = dense(x)
# The TE higher-precision calls use the bias fusion, so their results
# do not exactly match the TF calls.
atol, rtol = (0.01, 0.05) if use_fp8 else (1e-3, 1e-3)
self.assertAllClose(y, y_ref, rtol, atol, msg=f"use_fp8={use_fp8}")
@test_util.run_gpu_only
def testDenseBwd(self):
B, M, K, N = 4, 8, 16, 32
init = initializers.RandomUniform(minval=0., maxval=1.)
dense_kwargs = {
"units": N,
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
dense_ref = layers.Dense(**dense_kwargs)
dense = Dense(**dense_kwargs)
dy = tf.random.uniform((B, M, N))
def _train_step(x, model, use_fp8=False, fp8_recipe=None):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = model(x, training=True)
loss = y * tf.cast(dy, y.dtype)
dx, (dw, db) = tape.gradient(loss, [x, model.trainable_variables])
return dx, dw, db
x = tf.random.uniform((B, M, K))
for use_fp8, use_override in product([False, True], repeat=2):
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
recipe = get_fp8_recipe(use_override)
dx_ref, dw_ref, db_ref = _train_step(x, dense_ref)
dx, dw, db = _train_step(
x, dense, use_fp8=use_fp8, fp8_recipe=recipe)
assert_msg = f"use_fp8={use_fp8},use_override={use_override}"
atol, rtol = (0.01, 0.05) if use_fp8 else (1e-6, 1e-6)
self.assertAllClose(dx, dx_ref, rtol, atol, msg="dx," + assert_msg)
self.assertAllClose(db, db_ref, rtol, atol, msg="db," + assert_msg)
atol, rtol = \
(0.01, 0.05) if use_fp8 and not use_override else (1e-6, 1e-6)
self.assertAllClose(dw, dw_ref, rtol, atol, msg="dw," + assert_msg)
@test_util.run_gpu_only
def testDenseSkipWeight(self):
B, M, K, N = 4, 8, 16, 32
init = initializers.RandomUniform(minval=0., maxval=1.)
dense_kwargs = {
"units": N,
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
dense_ref = layers.Dense(**dense_kwargs)
dense = Dense(**dense_kwargs, skip_weight_param_allocation=True)
x = tf.random.uniform((B, M, K))
fp8_recipe = get_fp8_recipe()
for use_fp8 in [False, True]:
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
y_ref = dense_ref(x)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = dense(x, kernel=dense_ref.kernel, bias=dense_ref.bias)
atol, rtol = (0.01, 0.05) if use_fp8 else (1e-3, 1e-3)
self.assertAllClose(y, y_ref, rtol, atol, msg=f"use_fp8={use_fp8}")
@test_util.run_gpu_only
def testDenseBookkeeping(self):
if not tf.test.is_gpu_available(True, (9, 0)):
self.skipTest('Fp8 requires Hopper+ GPU')
M, K, N = 16, 16, 32
init = initializers.RandomNormal(mean=0., stddev=1.)
dense = Dense(N, kernel_initializer=init)
fp8_recipe = get_fp8_recipe()
def _train_step(x, dy):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
y = dense(x, training=True)
loss = y * tf.cast(dy, y.dtype)
dx, dw = tape.gradient(loss, [x, dense.kernel])
return dx, dw
scale_fwd_ref = tf.ones((2,))
scale_bwd_ref = tf.ones((1,))
scale_inv_fwd_ref = 1. / scale_fwd_ref
scale_inv_bwd_ref = 1. / scale_bwd_ref
amax_h_fwd_ref = tf.zeros((fp8_recipe.amax_history_len, 2))
amax_h_bwd_ref = tf.zeros((fp8_recipe.amax_history_len, 1))
atol, rtol = 0.001, 0.001
for step in range(5):
x = tf.random.normal((M, K))
dy = tf.random.normal((M, N))
dx, dw = _train_step(x, dy)
amax_x = tf.math.reduce_max(tf.math.abs(x))
amax_w = tf.math.reduce_max(tf.math.abs(dense.kernel))
amax_dy = tf.math.reduce_max(tf.math.abs(dy))
amax_h_fwd_ref = roll_and_update(amax_h_fwd_ref, [amax_x, amax_w])
amax_h_bwd_ref = roll_and_update(amax_h_bwd_ref, [amax_dy])
amax_h_fwd = dense.fp8_meta['scaling_fwd']['amax_history']
amax_h_bwd = dense.fp8_meta['scaling_bwd']['amax_history']
scale_fwd = dense.fp8_meta['scaling_fwd']['scale']
scale_bwd = dense.fp8_meta['scaling_bwd']['scale']
scale_inv_fwd = dense.fp8_meta['scaling_fwd']['scale_inv']
scale_inv_bwd = dense.fp8_meta['scaling_bwd']['scale_inv']
self.assertAllClose(
amax_h_fwd, amax_h_fwd_ref, rtol, atol, msg="amax_history_fwd")
self.assertAllClose(
amax_h_bwd, amax_h_bwd_ref, rtol, atol, msg="amax_history_bwd")
self.assertAllClose(scale_fwd, scale_fwd_ref,
rtol, atol, msg="scale_fwd")
self.assertAllClose(scale_bwd, scale_bwd_ref,
rtol, atol, msg="scale_bwd")
self.assertAllClose(
scale_inv_fwd, scale_inv_fwd_ref, rtol, atol,
msg="scale_inv_fwd")
self.assertAllClose(
scale_inv_bwd, scale_inv_bwd_ref, rtol, atol,
msg="scale_inv_bwd")
scale_fwd_ref, scale_inv_fwd_ref = update_scale(
amax_h_fwd_ref, scale_fwd_ref, dense.fp8_meta, is_fwd=True)
scale_bwd_ref, scale_inv_bwd_ref = update_scale(
amax_h_bwd_ref, scale_bwd_ref, dense.fp8_meta, is_fwd=False)
# Apply an update to the kernel to mimic gradient descent.
dense.kernel.assign_add(tf.cast(dw, tf.float32) * 0.1)
@test_util.run_gpu_only
def testLayerNormFwd(self):
B, M, N = 4, 16, 32
init = initializers.RandomNormal(mean=0., stddev=1.)
# The Keras layer norm uses fp32 computation in mixed-precision mode.
# So, for a fair comparison, we use fp32 in both the reference and
# target layers.
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
"dtype": 'float32',
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
ln = LayerNorm(**ln_kwargs)
x = tf.random.normal((B, M, N))
y_ref = ln_ref(x)
y = ln(x)
self.assertAllClose(y, y_ref, msg="fwd_layer_norm:y")
@test_util.run_gpu_only
def testLayerNormBwd(self):
B, M, N = 4, 16, 32
init = initializers.RandomNormal(mean=0., stddev=1.)
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
"dtype": 'float32',
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
ln = LayerNorm(**ln_kwargs)
dy = tf.random.uniform((B, M, N))
def _train_step(x, model):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y = model(x, training=True)
loss = y * tf.cast(dy, y.dtype)
dx, (dg, dB) = tape.gradient(loss, [x, model.trainable_variables])
return dx, dg, dB
x = tf.random.uniform((B, M, N))
dx_ref, dg_ref, dB_ref = _train_step(x, ln_ref)
dx, dg, dB = _train_step(x, ln)
self.assertAllClose(dx, dx_ref, msg="bwd_layer_norm:dx")
self.assertAllClose(dB, dB_ref, msg="bwd_layer_norm:dbeta")
self.assertAllClose(dg, dg_ref, msg="bwd_layer_norm:dgamma")
@test_util.run_gpu_only
def testLayerNormDenseFwd(self):
B, M, K, N = 4, 8, 16, 32
init = initializers.RandomUniform(minval=0., maxval=1.)
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
}
dense_kwargs = {
"units": N,
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
dense_ref = layers.Dense(**dense_kwargs)
x = tf.random.uniform((B, M, K))
fp8_recipe = get_fp8_recipe()
for use_fp8, output_ln in product([False, True], repeat=2):
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
ln_dense = LayerNormDense(
**ln_kwargs,
**dense_kwargs,
return_layernorm_output=output_ln,
)
y_ln_ref = ln_ref(x)
y_ref = dense_ref(y_ln_ref)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ys = ln_dense(x)
if output_ln:
y, y_ln = ys
else:
y = ys
assert_msg = f"use_fp8={use_fp8},output_ln={output_ln}"
atol, rtol = (0.01, 0.1) if use_fp8 else (1e-3, 1e-3)
self.assertAllClose(y, y_ref, rtol, atol, msg="y," + assert_msg)
if output_ln:
self.assertAllClose(
y_ln, y_ln_ref, rtol, atol, msg="y_ln," + assert_msg)
@test_util.run_gpu_only
def testLayerNormDenseBwd(self):
B, M, K, N = 4, 8, 16, 32
init = initializers.RandomUniform(minval=0., maxval=.1)
dy = tf.random.uniform((B, M, N), minval=0., maxval=1.)
x = tf.random.uniform((B, M, K), minval=0., maxval=1.)
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
}
dense_kwargs = {
"units": N,
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
dense_ref = layers.Dense(**dense_kwargs)
ln_dense = LayerNormDense(**ln_kwargs, **dense_kwargs)
def _train_step(x, model, use_fp8=False, fp8_recipe=None):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = model(x, training=True)
loss = y * tf.cast(dy, y.dtype)
dx, (dg, dB, dw, db) = tape.gradient(
loss, [x, model.trainable_variables])
return dx, dg, dB, dw, db
def _train_step_ref(x):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
t = ln_ref(x)
y = dense_ref(t)
loss = y * tf.cast(dy, y.dtype)
var_list = ln_ref.variables + dense_ref.variables
dx, dt, (dg, dB, dw, db) = tape.gradient(loss, [x, t, var_list])
return dx, dt, dg, dB, dw, db
for use_fp8, use_override in product([False, True], repeat=2):
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
recipe = get_fp8_recipe(use_override)
dx_ref, ln_dy_ref, dg_ref, dB_ref, dw_ref, db_ref = _train_step_ref(
x)
dx, dg, dB, dw, db = _train_step(
x, ln_dense, use_fp8=use_fp8, fp8_recipe=recipe)
assert_msg = f"use_fp8={use_fp8},use_override={use_override}"
atol, rtol = (0.01, 0.1) if use_fp8 else (1e-3, 1e-3)
self.assertAllClose(db, db_ref, rtol, atol,
msg="dbias," + assert_msg)
self.assertAllClose(dw, dw_ref, rtol, atol,
msg="dkernel," + assert_msg)
atol, rtol = (0.1, 0.1) if use_fp8 else (1e-2, 1e-2)
self.assertAllClose(dx, dx_ref, rtol, atol,
msg="ln_dx," + assert_msg)
self.assertAllClose(dg, dg_ref, rtol, atol,
msg="dgamma," + assert_msg)
self.assertAllClose(dB, dB_ref, rtol, atol,
msg="dbeta," + assert_msg)
@test_util.run_gpu_only
def testLayerNormDenseSkipWeight(self):
B, M, K, N = 4, 8, 16, 32
init = initializers.RandomUniform(minval=0., maxval=1.)
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
}
dense_kwargs = {
"units": N,
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
dense_ref = layers.Dense(**dense_kwargs)
ln_dense = LayerNormDense(
**ln_kwargs,
**dense_kwargs,
skip_weight_param_allocation=True,
)
x = tf.random.uniform((B, M, K))
fp8_recipe = get_fp8_recipe()
for use_fp8 in [False, True]:
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
y_ref = dense_ref(ln_ref(x))
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = ln_dense(x, kernel=dense_ref.kernel, bias=dense_ref.bias)
atol, rtol = (0.01, 0.1) if use_fp8 else (1e-3, 1e-3)
self.assertAllClose(y, y_ref, rtol, atol, msg=f"use_fp8={use_fp8}")
@test_util.run_gpu_only
def testLayerNormMLPFwd(self):
B, M, K, N, O = 4, 8, 16, 32, 64
init = initializers.RandomUniform(minval=0., maxval=1.)
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
}
dense_common_kwargs = {
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
dense1_ref = layers.Dense(**dense_common_kwargs, units=N)
dense2_ref = layers.Dense(**dense_common_kwargs, units=O)
x = tf.random.uniform((B, M, K))
fp8_recipe = get_fp8_recipe()
for use_fp8, output_ln in product([False, True], repeat=2):
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
ln_mlp = LayerNormMLP(
**ln_kwargs,
**dense_common_kwargs,
units=N,
ffn_units=O,
ffn_kernel_initializer=init,
return_layernorm_output=output_ln,
)
y_ln_ref = ln_ref(x)
y_dense1_ref = dense1_ref(y_ln_ref)
y_gelu_ref = tf.nn.gelu(y_dense1_ref, approximate=True)
y_ref = dense2_ref(y_gelu_ref)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ys = ln_mlp(x)
if output_ln:
y, y_ln = ys
else:
y = ys
assert_msg = f"use_fp8={use_fp8},output_ln={output_ln}"
atol, rtol = (0.01, 0.1) if use_fp8 else (1e-3, 2e-3)
self.assertAllClose(y, y_ref, rtol, atol, msg="y," + assert_msg)
if output_ln:
self.assertAllClose(
y_ln, y_ln_ref, rtol, atol, msg="y_ln," + assert_msg)
@test_util.run_gpu_only
def testLayerNormMLPBwd(self):
B, M, K, N, O = 4, 8, 16, 32, 64
init = initializers.RandomUniform(minval=0., maxval=.1)
dy = tf.random.uniform((B, M, O), minval=0., maxval=1.)
x = tf.random.uniform((B, M, K), minval=0., maxval=1.)
ln_kwargs = {
"gamma_initializer": init,
"beta_initializer": init,
}
dense_common_kwargs = {
"use_bias": True,
"kernel_initializer": init,
"bias_initializer": init,
}
ln_ref = layers.LayerNormalization(**ln_kwargs)
dense1_ref = layers.Dense(**dense_common_kwargs, units=N)
dense2_ref = layers.Dense(**dense_common_kwargs, units=O)
ln_mlp = LayerNormMLP(
**ln_kwargs,
**dense_common_kwargs,
units=N,
ffn_units=O,
ffn_kernel_initializer=init,
)
def _train_step(x, model, use_fp8=False, fp8_recipe=None):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = model(x, training=True)
loss = y * tf.cast(dy, y.dtype)
dx, (dg, dB, dw1, db1, dw2, db2) = tape.gradient(
loss, [x, model.trainable_variables])
return dx, dg, dB, dw1, db1, dw2, db2
def _train_step_ref(x):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
t = ln_ref(x)
y_gelu = tf.nn.gelu(dense1_ref(t), approximate=True)
y = dense2_ref(y_gelu)
loss = y * tf.cast(dy, y.dtype)
var_list = ln_ref.variables + dense1_ref.variables + \
dense2_ref.variables
dx, dt, (dg, dB, dw1, db1, dw2, db2) = tape.gradient(
loss, [x, t, var_list])
return dx, dt, dg, dB, dw1, db1, dw2, db2
for use_fp8, use_override in product([False, True], repeat=2):
if use_fp8 and not tf.test.is_gpu_available(True, (9, 0)):
continue
recipe = get_fp8_recipe(use_override)
dx_ref, ln_dy_ref, dg_ref, dB_ref, dw1_ref, db1_ref, dw2_ref, \
db2_ref = _train_step_ref(x)
dx, dg, dB, dw1, db1, dw2, db2 = _train_step(
x, ln_mlp, use_fp8=use_fp8, fp8_recipe=recipe)
assert_msg = f"use_fp8={use_fp8},use_override={use_override}"
atol, rtol = (0.01, 0.1) if use_fp8 else (1e-3, 1e-3)
self.assertAllClose(
db2, db2_ref, rtol, atol, msg="fc2_dbias," + assert_msg)
self.assertAllClose(
dw2, dw2_ref, rtol, atol, msg="fc2_dw," + assert_msg)
self.assertAllClose(
db1, db1_ref, rtol, atol, msg="fc1_dbias," + assert_msg)
self.assertAllClose(
dw1, dw1_ref, rtol, atol, msg="fc1_dw," + assert_msg)
atol, rtol = (0.1, 0.1) if use_fp8 else (1e-2, 1e-2)
self.assertAllClose(dx, dx_ref, rtol, atol,
msg="ln_dx," + assert_msg)
self.assertAllClose(dg, dg_ref, rtol, atol,
msg="dgamma," + assert_msg)
self.assertAllClose(dB, dB_ref, rtol, atol,
msg="dbeta," + assert_msg)
if __name__ == '__main__':
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the MHA layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import transformer_engine.tensorflow as te
from tensorflow.keras.layers import EinsumDense
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from transformer_engine.tensorflow import (
DelayedScaling,
Format,
MultiHeadAttention,
)
def train_step(dy, x_q, x_kv, x_mask, model, attn_type, use_fp8=False,
fp8_recipe=None):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x_q)
if attn_type == 'cross':
tape.watch(x_kv)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
# The MHA does not apply the bias addition in the final projection but
# returns the bias separately, so we apply the bias addition here.
y, b = model(x_q, x_mask, x_kv, training=True)
y = y + tf.cast(b, y.dtype)
loss = y * tf.cast(dy, dtype=y.dtype)
xs = [x_q]
if attn_type == 'cross':
xs.append(x_kv)
dxs, dvars = tape.gradient(loss, [xs, model.trainable_variables])
return y, dxs, dvars
class MultiHeadAttentionKeras(tf.keras.Model):
def __init__(self, hidden_size, num_heads, attention_type, init_method):
super(MultiHeadAttentionKeras, self).__init__()
assert hidden_size % num_heads == 0
assert attention_type in ('self', 'cross')
self.num_heads = num_heads
self.hidden_size = hidden_size
self.depth = hidden_size // self.num_heads
self.attention_type = attention_type
# Einsum symbols:
# F=seq_q, T=seq_kv, B=batches, H=hidden_states, D=hidden_size,
# N=num_heads, E=depth
if attention_type == 'self':
self.QKV = EinsumDense('FBH,HD->FBD',
output_shape=(None, 3 * hidden_size),
bias_axes='D',
kernel_initializer=init_method)
else:
self.Q = EinsumDense('FBH,HD->FBD',
output_shape=(None, hidden_size),
bias_axes='D',
kernel_initializer=init_method)
self.KV = EinsumDense('TBH,HD->TBD',
output_shape=(None, 2 * hidden_size),
bias_axes='D',
kernel_initializer=init_method)
# The projection bias is applied separately outside the MHA, so we
# disable the bias in the Einsum and handle it at the end.
self.dense = EinsumDense('FBNE,NED->FBD',
output_shape=(None, hidden_size),
bias_axes=None,
kernel_initializer=init_method)
b_init = tf.zeros_initializer()
self.dense_bias = tf.Variable(
initial_value=b_init(shape=(hidden_size,),
dtype="float32"),
trainable=True,
)
def __call__(self, q_input, mask=None, kv_input=None, training=None):
if self.attention_type == 'self':
# [F, B, 3 * D]
qkv = self.QKV(q_input)
# [F, B, N, 3 * E]
qkv = tf.reshape(
qkv, (*qkv.shape[: -1],
self.num_heads, 3 * self.depth))
# 3 * [F, B, N, E]
q, k, v = tf.split(qkv, num_or_size_splits=3, axis=-1)
else:
# [F, B, D]
q = self.Q(q_input)
# [F, B, N, E]
q = tf.reshape(q, (*q.shape[:-1], self.num_heads, self.depth))
# [F, B, 2 * D]
kv = self.KV(kv_input)
# [F, B, N, 2 * E]
kv = tf.reshape(
kv, (*kv.shape[: -1],
self.num_heads, 2 * self.depth))
# 2 * [F, B, N, E]
k, v = tf.split(kv, num_or_size_splits=2, axis=-1)
dk = tf.cast(tf.shape(k)[-1], self._compute_dtype_object)
matmul_qk = tf.einsum('FBNE,TBNE->BNFT', q, k)
scaled_attn_logits = matmul_qk / tf.math.sqrt(dk)
if mask is not None:
scaled_attn_logits = tf.where(mask, scaled_attn_logits, -10000.0)
# [B, N, F, T]
attention_weights = tf.nn.softmax(scaled_attn_logits, axis=-1)
# [B, N, F, E]
scaled_attention = tf.einsum('BNFT,TBNE->BNFE', attention_weights, v)
# [F, B, N, E]
scaled_attention = tf.transpose(scaled_attention, perm=(2, 0, 1, 3))
# [F, B, D]
output = self.dense(scaled_attention)
return output, self.dense_bias
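The attention math in `__call__` above can be sketched for a single query vector in plain Python; `attention_row` is an illustrative reference, not part of the test:

```python
import math

def attention_row(q, keys, values, mask=None):
    # softmax(q . k / sqrt(dk)) over the keys, then a weighted sum of
    # values; masked-out positions get the same -10000.0 fill as above.
    dk = len(q)
    logits = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dk)
              for k in keys]
    if mask is not None:
        logits = [l if m else -10000.0 for l, m in zip(logits, mask)]
    mx = max(logits)
    exps = [math.exp(l - mx) for l in logits]
    z = sum(exps)
    weights = [e / z for e in exps]
    return [sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))]
```

With one-hot values, the output equals the attention weights, which makes the masking behavior easy to inspect.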
class MHATest(test.TestCase):
@test_util.run_gpu_only
def testMHAForward(self):
use_fp8 = tf.test.is_gpu_available(True, (9, 0))
batches, seq_q, seq_kv, hidden_states = 16, 32, 32, 64
num_heads, depth = 4, 16
hidden_size = num_heads * depth
q_shape = (seq_q, batches, hidden_states)
kv_shape = (seq_kv, batches, hidden_states)
init = tf.keras.initializers.RandomUniform(minval=0., maxval=.1)
x_q = tf.random.uniform(q_shape, minval=0., maxval=.1)
x_kv = tf.random.uniform(kv_shape, minval=0., maxval=.1)
for attn_type in ('self', 'cross'):
for use_mask in (True, False):
mha_einsum = MultiHeadAttentionKeras(
hidden_size, num_heads, attn_type, init)
# The attention mask type needs to be `padding`, which uses the
# provided mask. Alternatively, `causal` ignores the provided mask
# and uses an upper-triangular mask.
mha = MultiHeadAttention(
hidden_size=hidden_size,
num_attention_heads=num_heads,
kv_channels=depth,
attention_dropout=0.0,
attention_softmax_in_fp32=True,
init_method=init,
output_layer_init_method=init,
input_layernorm=False,
attention_type=attn_type,
attn_mask_type='padding',
)
x_mask = tf.random.uniform(
(seq_q, seq_kv)) > 0.5 if use_mask else None
y_ref, y_b_ref = mha_einsum(x_q, x_mask, x_kv)
fp8_recipe = DelayedScaling(
margin=0, interval=1, fp8_format=Format.HYBRID,
amax_compute_algo='max', amax_history_len=3)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y, y_b = mha(x_q, x_mask, x_kv)
self.assertAllClose(y, y_ref, rtol=0.01, atol=0.01, msg='y')
self.assertAllClose(y_b, y_b_ref, msg='y_bias')
@test_util.run_gpu_only
def testMHABackward(self):
use_fp8 = tf.test.is_gpu_available(True, (9, 0))
batches, seq_q, seq_kv, hidden_states = 4, 8, 8, 32
num_heads, depth = 4, 8
hidden_size = num_heads * depth
q_shape = (seq_q, batches, hidden_states)
kv_shape = (seq_kv, batches, hidden_states)
out_shape = (seq_q, batches, hidden_size)
init = tf.keras.initializers.RandomUniform(minval=0., maxval=.1)
x_q = tf.random.uniform(q_shape, minval=0., maxval=.1)
x_kv = tf.random.uniform(kv_shape, minval=0., maxval=.1)
dy = tf.random.uniform(out_shape, minval=0., maxval=1.)
for attn_type in ('self', 'cross'):
for use_mask in (False, True):
mha_einsum = MultiHeadAttentionKeras(
hidden_size, num_heads, attn_type, init)
mha = MultiHeadAttention(
hidden_size=hidden_size,
num_attention_heads=num_heads,
kv_channels=depth,
attention_dropout=0.0,
attention_softmax_in_fp32=True,
init_method=init,
output_layer_init_method=init,
input_layernorm=False,
attention_type=attn_type,
attn_mask_type='padding',
)
x_mask = tf.random.uniform(
(seq_q, seq_kv)) > 0.5 if use_mask else None
y_ref, dxs_ref, dvars_ref = train_step(
dy, x_q, x_kv, x_mask, mha_einsum, attn_type)
fp8_recipe = DelayedScaling(
margin=0, interval=1, fp8_format=Format.HYBRID,
amax_compute_algo='max', amax_history_len=3)
y, dxs, dvars = train_step(
dy, x_q, x_kv, x_mask, mha, attn_type, use_fp8, fp8_recipe)
for dx, dx_ref in zip(dxs, dxs_ref):
self.assertAllClose(
dx, dx_ref, rtol=0.1, atol=0.1, msg='dx')
if attn_type == 'cross':
# The variable lists are:
# [q_w, kv_w, q_b, kv_b, proj_w, proj_b] (target)
# [q_w, q_b, kv_w, kv_b, proj_w, proj_b] (reference)
self.assertEqual(len(dvars), 6)
self.assertEqual(len(dvars), len(dvars_ref))
dws = [dvars[i] for i in [0, 1, 4]]
dws_ref = [dvars_ref[i] for i in [0, 2, 4]]
dbs = [dvars[i] for i in [2, 3, 5]]
dbs_ref = [dvars_ref[i] for i in [1, 3, 5]]
else:
# The variable lists are:
# [qkv_w, qkv_b, proj_w, proj_b] (target)
# [qkv_w, qkv_b, proj_w, proj_b] (reference)
self.assertEqual(len(dvars), 4)
self.assertEqual(len(dvars), len(dvars_ref))
dws = [dvars[i] for i in [0, 2]]
dws_ref = [dvars_ref[i] for i in [0, 2]]
dbs = [dvars[i] for i in [1, 3]]
dbs_ref = [dvars_ref[i] for i in [1, 3]]
for dv, dv_ref in zip(dws, dws_ref):
self.assertAllClose(
dv, tf.reshape(dv_ref, dv.shape),
rtol=0.1, atol=0.1, msg='dkernel')
for dv, dv_ref in zip(dbs, dbs_ref):
self.assertAllClose(dv, dv_ref, rtol=0.2,
atol=0.2, msg='dbias')
if __name__ == '__main__':
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test.main()
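The `padding` vs. `causal` distinction exercised in the tests above can be illustrated in plain Python; `causal_mask` is a hypothetical helper for illustration, not a TE API:

```python
def causal_mask(seq_q, seq_kv):
    # Causal masking: query position i may attend only to key positions
    # j <= i; the strict upper triangle is masked out. A padding mask,
    # by contrast, is supplied by the caller per position.
    return [[j <= i for j in range(seq_kv)] for i in range(seq_q)]
```

For a 3x3 case this allows one position in the first row, two in the second, and all three in the last.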
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
try:
import transformer_engine.tensorflow
te_imported = True
except ImportError:
te_imported = False
assert te_imported, 'transformer_engine import failed'
print("OK")
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the Transformer layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import transformer_engine.tensorflow as te
from tensorflow.keras.layers import EinsumDense
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from transformer_engine.tensorflow import (
DelayedScaling,
Format,
TransformerLayer,
)
def train_step(dy, x, x_mask, x_dec, x_dec_mask, model, use_fp8=False,
fp8_recipe=None):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
y = model(
hidden_states=x,
attention_mask=x_mask,
encoder_output=x_dec,
enc_dec_attn_mask=x_dec_mask,
training=True,
)
loss = y * tf.cast(dy, dtype=y.dtype)
dx, dvars = tape.gradient(loss, [x, model.trainable_variables])
return y, dx, dvars
class TransformerLayerTest(test.TestCase):
@test_util.run_gpu_only
def testTransformerSanity(self):
use_fp8 = tf.test.is_gpu_available(True, (9, 0))
# F=seq_len, B=batch, H=hidden_states, N=num_heads
F, B, H, N = 8, 4, 32, 2
# E=depth
E = H // N
# D=hidden_size
D = N * E
input_shape = (F, B, H)
output_shape = (F, B, D)
init = tf.keras.initializers.RandomUniform(minval=0., maxval=.1)
x = tf.random.uniform(input_shape, minval=0., maxval=.1)
x_dec = tf.random.uniform(input_shape, minval=0., maxval=10.)
dy = tf.random.uniform(output_shape, minval=0., maxval=.1)
transformer = TransformerLayer(
hidden_size=D,
ffn_hidden_size=D,
num_attention_heads=N,
layernorm_epsilon=1e-5,
hidden_dropout=0.01,
attention_dropout=0.0,
init_method=init,
output_layer_init_method=init,
layer_number=None,
kv_channels=None,
self_attn_mask_type="padding",
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
drop_path_rate=0.1,
fuse_qkv_params=False,
)
fp8_recipe = DelayedScaling(
margin=0, interval=1, fp8_format=Format.HYBRID,
amax_compute_algo='max', amax_history_len=3)
y_ref, dx_ref, dvars_ref = train_step(
dy, x, None, x_dec, None, transformer, use_fp8=False)
y, dx, dvars = train_step(dy, x, None, x_dec, None, transformer,
use_fp8=use_fp8, fp8_recipe=fp8_recipe)
self.assertAllClose(y, y_ref, rtol=0.1, atol=0.01, msg="fwd-y")
self.assertAllClose(dx, dx_ref, rtol=0.5, atol=0.7, msg="bwd-dx")
self.assertEqual(len(dvars), len(dvars_ref))
dvs = []
for v, dv, dv_ref in zip(
transformer.trainable_variables, dvars, dvars_ref):
dvs.append((v.name, dv, dv_ref))
for v_name, dv, dv_ref in reversed(dvs):
# The ranges of these two biases are relatively large, so we choose
# larger atols here.
if v_name == 'multi_head_attention/dense/bias:0':
self.assertAllClose(dv, dv_ref, rtol=.1,
atol=4., msg="bwd-" + v_name)
continue
if v_name == 'multi_head_attention/qkv_bias:0':
self.assertAllClose(dv, dv_ref, rtol=.1,
atol=2., msg="bwd-" + v_name)
continue
atol, rtol = (0.5, 0.6) if tf.reduce_max(
dv_ref) > 1. else (0.05, 0.05)
self.assertAllClose(dv, dv_ref, rtol=rtol,
atol=atol, msg="bwd-" + v_name)
if __name__ == '__main__':
tf.keras.mixed_precision.set_global_policy('mixed_float16')
test.main()
@@ -32,3 +32,9 @@ if(ENABLE_JAX)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(jax)
endif()
option(ENABLE_TENSORFLOW "Enable TensorFlow in the building workflow." OFF)
if(ENABLE_TENSORFLOW)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(tensorflow)
endif()
@@ -15,3 +15,8 @@ try:
from . import jax
except ImportError as e:
pass
try:
from . import tensorflow
except ImportError as e:
pass
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
pybind11_add_module(
transformer_engine_tensorflow
${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cu
)
add_library(
_get_stream SHARED
${CMAKE_CURRENT_SOURCE_DIR}/csrc/get_stream_op.cpp
)
# Includes
execute_process(COMMAND ${Python_EXECUTABLE} -c "import tensorflow as tf; print(tf.sysconfig.get_include())"
OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${Python_EXECUTABLE} -c "import numpy as np; print(np.get_include())"
OUTPUT_VARIABLE Numpy_INCLUDE_DIRS OUTPUT_STRIP_TRAILING_WHITESPACE)
target_include_directories(transformer_engine_tensorflow PRIVATE
${Tensorflow_INCLUDE_DIRS}
${Tensorflow_INCLUDE_DIRS}/external/farmhash_archive/src
${Numpy_INCLUDE_DIRS})
target_include_directories(_get_stream PRIVATE ${Tensorflow_INCLUDE_DIRS})
# Libraries
execute_process(COMMAND ${Python_EXECUTABLE} -c "import tensorflow as tf; print(tf.__file__)"
OUTPUT_VARIABLE Tensorflow_LIB_PATH OUTPUT_STRIP_TRAILING_WHITESPACE)
get_filename_component(Tensorflow_LIB_PATH ${Tensorflow_LIB_PATH} DIRECTORY)
list(APPEND TF_LINKER_LIBS "${Tensorflow_LIB_PATH}/libtensorflow_framework.so.2")
list(APPEND TF_LINKER_LIBS "${Tensorflow_LIB_PATH}/python/_pywrap_tensorflow_internal.so")
target_link_libraries(
transformer_engine_tensorflow PRIVATE
${TF_LINKER_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine
)
target_link_libraries(_get_stream PRIVATE ${TF_LINKER_LIBS})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for TensorFlow."""
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format
from .constants import TE_DType
from .fp8 import fp8_autocast
from .module import Dense
from .module import LayerNorm
from .module import LayerNormDense
from .module import LayerNormMLP
from .module import get_stream_id
from .transformer import MultiHeadAttention
from .transformer import TransformerLayer
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Enums for e2e transformer"""
import tensorflow as tf
import transformer_engine_tensorflow as tex
"""
Map from tf.dtype to the corresponding TE enum value.
Used for passing dtypes into the CUDA extension. Has a
one-to-one mapping with the enum in transformer_engine.h.
"""
TE_DType = {
tf.int8: tex.DType.kByte,
tf.int32: tex.DType.kInt32,
tf.float32: tex.DType.kFloat32,
tf.half: tex.DType.kFloat16,
tf.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding")
AttnTypes = ("self", "cross")
LayerTypes = ("encoder", "decoder")