Unverified Commit 7f5e4cb9 authored by Kaixi Hou, committed by GitHub

Remove the autocast_variable from TF-TE (#141)



Remove the autocast_variable
Signed-off-by: kaixih <kaixih@nvidia.com>
parent ab44f050
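
In short, the commit drops the dependency on Keras's private `autocast_variable` module and replaces the implicit autocast scopes with explicit `tf.cast` calls on values read out of the variables. A minimal sketch of the before/after pattern, standalone and outside TE:

```python
import tensorflow as tf

bias_var = tf.Variable(tf.zeros([4], dtype=tf.float32))
compute_dtype = "float16"

# Before (removed by this commit): the read was wrapped in a private Keras
# autocast scope, which cast the variable's value implicitly:
#   with autocast_variable.enable_auto_cast_variables(dtype):
#       bias = bias_var.value()

# After: read an EagerTensor out of the Variable and cast it explicitly.
bias = bias_var.value()
if compute_dtype == "float16":
    bias = tf.cast(bias, compute_dtype)
print(bias.dtype)  # float16
```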
@@ -5,7 +5,6 @@
 """Top level Transformer Engine PyTorch modules"""
 from typing import Union, Callable
 from keras import backend, layers, initializers
-from keras.mixed_precision import autocast_variable
 import tensorflow as tf
 import transformer_engine_tensorflow as tex
@@ -56,8 +55,13 @@ def get_autocast_bias(dtype, bias_var, use_bias, use_fp8):
     """Get casted bias for fp8 gemm."""
     if not use_bias:
         return None
-    with autocast_variable.enable_auto_cast_variables(dtype):
-        bias = bias_var.value()
+    # We need to pass the EagerTensor instead of Variable when calling into the
+    # pybind functions. So, we use value() for the explicit conversion.
+    bias = bias_var.value()
+    if dtype == "float16":
+        bias = tf.cast(bias, dtype)
     if use_fp8 and bias.dtype == tf.float32:
         bias = tf.cast(bias, dtype=tf.bfloat16)
     return bias
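
To see the dtype flow of the rewritten helper in isolation, here is a runnable copy of its new body with a quick check of both cast branches (the variable and dtypes mirror the diff; nothing TE-specific is needed):

```python
import tensorflow as tf

def get_autocast_bias(dtype, bias_var, use_bias, use_fp8):
    """Copy of the new helper body above, runnable outside TE."""
    if not use_bias:
        return None
    # value() yields the EagerTensor that the pybind kernels expect.
    bias = bias_var.value()
    if dtype == "float16":
        bias = tf.cast(bias, dtype)
    # fp8 gemms take the bias in bfloat16 when the params are float32.
    if use_fp8 and bias.dtype == tf.float32:
        bias = tf.cast(bias, dtype=tf.bfloat16)
    return bias

bias_var = tf.Variable(tf.zeros([8], dtype=tf.float32))
print(get_autocast_bias("float16", bias_var, True, use_fp8=False).dtype)  # float16
print(get_autocast_bias("float32", bias_var, True, use_fp8=True).dtype)   # bfloat16
```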
@@ -527,11 +531,12 @@ class Dense(TransformerEngineBaseModule, layers.Layer):
         """Prep fwd+bwd non-fp8 matmul."""
         @tf.custom_gradient
         def non_fp8_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=False,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=False,
             )
             output_dtype = self._compute_dtype_object
@@ -577,11 +582,12 @@ class Dense(TransformerEngineBaseModule, layers.Layer):
         @tf.custom_gradient
         def fp8_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=True,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=True,
             )
             if not override_linear_precision.wgrad:
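
Note that the call sites also switch from the private `self._compute_dtype_object` (a DType object) to the public `self.compute_dtype`. That matters because `compute_dtype` is a plain string, which is what the new `dtype == "float16"` check in `get_autocast_bias` compares against. A quick check with a stock Keras layer:

```python
import tensorflow as tf

# "mixed_float16" stores variables in float32 but computes in float16.
layer = tf.keras.layers.Dense(4, dtype="mixed_float16")
print(type(layer.compute_dtype))         # <class 'str'>
print(layer.compute_dtype == "float16")  # True
```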
@@ -1017,7 +1023,9 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
         """Prep fwd+bwd non-fp8 layernorm followed by matmul."""
         @tf.custom_gradient
         def non_fp8_layernorm_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             gamma_val = gamma_var.value()
             beta_val = beta_var.value()
@@ -1027,8 +1035,7 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
             )
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=False,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=False,
             )
             output_dtype = self._compute_dtype_object
@@ -1097,7 +1104,9 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
         @tf.custom_gradient
         def fp8_layernorm_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             gamma_val = gamma_var.value()
             beta_val = beta_var.value()
@@ -1127,8 +1136,7 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
             )
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=True,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=True,
             )
             weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
@@ -1524,7 +1532,9 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
         """Prep fwd+bwd non-fp8 layernorm followed by mlp."""
         @tf.custom_gradient
         def non_fp8_layernorm_mlp_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             fc1_kernel_val = fc1_kernel_var.value()
             fc2_kernel_val = fc2_kernel_var.value()
             gamma_val = gamma_var.value()
@@ -1535,12 +1545,10 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
             )
             fc1_bias = get_autocast_bias(
-                self._compute_dtype_object, fc1_bias_var, use_bias=True,
-                use_fp8=False,
+                self.compute_dtype, fc1_bias_var, use_bias=True, use_fp8=False,
             )
             fc2_bias = get_autocast_bias(
-                self._compute_dtype_object, fc2_bias_var, self.use_bias,
-                use_fp8=False,
+                self.compute_dtype, fc2_bias_var, self.use_bias, use_fp8=False,
             )
             output_dtype = self._compute_dtype_object
@@ -1652,7 +1660,9 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
         @tf.custom_gradient
         def fp8_layernorm_mlp_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             fc1_kernel_val = fc1_kernel_var.value()
             fc2_kernel_val = fc2_kernel_var.value()
             gamma_val = gamma_var.value()
@@ -1683,12 +1693,10 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
             )
             fc1_bias = get_autocast_bias(
-                self._compute_dtype_object, fc1_bias_var, use_bias=True,
-                use_fp8=True,
+                self.compute_dtype, fc1_bias_var, use_bias=True, use_fp8=True,
             )
             fc2_bias = get_autocast_bias(
-                self._compute_dtype_object, fc2_bias_var, self.use_bias,
-                use_fp8=True,
+                self.compute_dtype, fc2_bias_var, self.use_bias, use_fp8=True,
             )
             fc1_weight_fp8, fc1_weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
@@ -9,7 +9,6 @@ from typing import Callable, Optional, Tuple, Union
 import os
 from keras import backend, layers, initializers
-from keras.mixed_precision import autocast_variable
 import tensorflow as tf
 from transformer_engine.tensorflow.module import (
@@ -770,9 +769,7 @@ class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
         bias_dropout_add_func = get_bias_dropout_add(training)
         # Bias dropout add.
-        # The autocast scope is used to enforce the correct dtype for the bias.
-        with autocast_variable.enable_auto_cast_variables(
-                self._compute_dtype_object):
+        attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype)
         if self.drop_path is None:
             bda_output = bias_dropout_add_func(
                 attention_output,
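
The explicit cast is needed because TensorFlow does not promote dtypes automatically: under mixed precision, the attention output arrives in float16 while the bias variable is stored in float32, and adding the two raises an error. A minimal reproduction (shapes and names are illustrative):

```python
import tensorflow as tf

attention_output = tf.zeros([2, 4], dtype=tf.float16)
attention_bias = tf.zeros([4], dtype=tf.float32)  # variables stay float32

# attention_output + attention_bias  # raises InvalidArgumentError (half vs float)

attention_bias = tf.cast(attention_bias, dtype="float16")
print((attention_output + attention_bias).dtype)  # float16
```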
@@ -809,11 +806,7 @@ class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
             attention_output, attention_bias = inter_attention_outputs
             residual = bda_output
-            # The autocast scope is used to enforce the correct dtype for the
-            # bias.
-            with autocast_variable.enable_auto_cast_variables(
-                self._compute_dtype_object
-            ):
+            attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype)
             bda_output = bias_dropout_add_func(
                 attention_output,
                 attention_bias,
@@ -833,9 +826,7 @@ class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
             residual = bda_output
         # Bias dropout add.
-        # The autocast scope is used to enforce the correct dtype for the bias.
-        with autocast_variable.enable_auto_cast_variables(
-                self._compute_dtype_object):
+        mlp_bias = tf.cast(mlp_bias, dtype=self.compute_dtype)
         if self.drop_path is None:
             output = bias_dropout_add_func(
                 mlp_output,