Unverified Commit 7f5e4cb9 authored by Kaixi Hou, committed by GitHub

Remove the autocast_variable from TF-TE (#141)



Remove the autocast_variable
Signed-off-by: kaixih <kaixih@nvidia.com>
parent ab44f050
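
In effect, this commit swaps Keras's implicit autocast-variable scope for explicit reads and casts. A minimal sketch of the before/after pattern, assuming a float32 variable and a float16 compute dtype (the names here are illustrative, not taken from the diff):

```python
import tensorflow as tf

bias_var = tf.Variable(tf.zeros([16]))  # stored in float32

# Before: an autocast scope casted AutoCastVariables on read, e.g.
#   with autocast_variable.enable_auto_cast_variables(tf.float16):
#       bias = bias_var.value()

# After: read the variable eagerly, then cast explicitly.
bias = bias_var.value()           # EagerTensor, as the pybind functions expect
bias = tf.cast(bias, tf.float16)  # dtype enforcement is now explicit
```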
@@ -5,7 +5,6 @@
 """Top level Transformer Engine PyTorch modules"""
 from typing import Union, Callable
 from keras import backend, layers, initializers
-from keras.mixed_precision import autocast_variable
 import tensorflow as tf
 import transformer_engine_tensorflow as tex
@@ -56,8 +55,13 @@ def get_autocast_bias(dtype, bias_var, use_bias, use_fp8):
     """Get casted bias for fp8 gemm."""
     if not use_bias:
         return None
-    with autocast_variable.enable_auto_cast_variables(dtype):
-        bias = bias_var.value()
+    # We need to pass the EagerTensor instead of Variable when calling into the
+    # pybind functions. So, we use value() for the explicit conversion.
+    bias = bias_var.value()
+    if dtype == "float16":
+        bias = tf.cast(bias, dtype)
     if use_fp8 and bias.dtype == tf.float32:
         bias = tf.cast(bias, dtype=tf.bfloat16)
     return bias
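
For reference, a standalone restatement of the updated helper's dtype logic (a sketch mirroring the hunk above, not an import from the package): under a float16 compute dtype the bias is cast to float16, and under fp8 a float32 bias is narrowed to bfloat16.

```python
import tensorflow as tf

def get_autocast_bias(dtype, bias_var, use_bias, use_fp8):
    """Resolve the bias dtype explicitly, as the updated helper does."""
    if not use_bias:
        return None
    # Pass an EagerTensor, not a Variable, into the pybind functions.
    bias = bias_var.value()
    if dtype == "float16":
        bias = tf.cast(bias, dtype)
    if use_fp8 and bias.dtype == tf.float32:
        bias = tf.cast(bias, dtype=tf.bfloat16)
    return bias

bias_var = tf.Variable(tf.zeros([8]))
print(get_autocast_bias("float16", bias_var, True, use_fp8=False).dtype)  # float16
print(get_autocast_bias("float32", bias_var, True, use_fp8=True).dtype)   # bfloat16
```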
@@ -527,11 +531,12 @@ class Dense(TransformerEngineBaseModule, layers.Layer):
         """Prep fwd+bwd non-fp8 matmul."""
         @tf.custom_gradient
         def non_fp8_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=False,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=False,
             )
             output_dtype = self._compute_dtype_object
@@ -577,11 +582,12 @@ class Dense(TransformerEngineBaseModule, layers.Layer):
         @tf.custom_gradient
         def fp8_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=True,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=True,
             )
             if not override_linear_precision.wgrad:
@@ -1017,7 +1023,9 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
         """Prep fwd+bwd non-fp8 layernorm followed by matmul."""
         @tf.custom_gradient
         def non_fp8_layernorm_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             gamma_val = gamma_var.value()
             beta_val = beta_var.value()
@@ -1027,8 +1035,7 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
             )
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=False,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=False,
             )
             output_dtype = self._compute_dtype_object
@@ -1097,7 +1104,9 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
         @tf.custom_gradient
         def fp8_layernorm_matmul_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             kernel_val = kernel_var.value()
             gamma_val = gamma_var.value()
             beta_val = beta_var.value()
@@ -1127,8 +1136,7 @@ class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
             )
             bias = get_autocast_bias(
-                self._compute_dtype_object, bias_var, self.use_bias,
-                use_fp8=True,
+                self.compute_dtype, bias_var, self.use_bias, use_fp8=True,
             )
             weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
@@ -1524,7 +1532,9 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
         """Prep fwd+bwd non-fp8 layernorm followed by mlp."""
         @tf.custom_gradient
         def non_fp8_layernorm_mlp_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             fc1_kernel_val = fc1_kernel_var.value()
             fc2_kernel_val = fc2_kernel_var.value()
             gamma_val = gamma_var.value()
@@ -1535,12 +1545,10 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
             )
             fc1_bias = get_autocast_bias(
-                self._compute_dtype_object, fc1_bias_var, use_bias=True,
-                use_fp8=False,
+                self.compute_dtype, fc1_bias_var, use_bias=True, use_fp8=False,
             )
             fc2_bias = get_autocast_bias(
-                self._compute_dtype_object, fc2_bias_var, self.use_bias,
-                use_fp8=False,
+                self.compute_dtype, fc2_bias_var, self.use_bias, use_fp8=False,
             )
             output_dtype = self._compute_dtype_object
@@ -1652,7 +1660,9 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
         @tf.custom_gradient
         def fp8_layernorm_mlp_func(x):
-            # Use value() to convert from Variable to EagerTensor
+            # We need to pass the EagerTensor instead of Variable when calling
+            # into the pybind functions. So, we use value() for the explicit
+            # conversion.
             fc1_kernel_val = fc1_kernel_var.value()
             fc2_kernel_val = fc2_kernel_var.value()
             gamma_val = gamma_var.value()
@@ -1683,12 +1693,10 @@ class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
             )
             fc1_bias = get_autocast_bias(
-                self._compute_dtype_object, fc1_bias_var, use_bias=True,
-                use_fp8=True,
+                self.compute_dtype, fc1_bias_var, use_bias=True, use_fp8=True,
             )
             fc2_bias = get_autocast_bias(
-                self._compute_dtype_object, fc2_bias_var, self.use_bias,
-                use_fp8=True,
+                self.compute_dtype, fc2_bias_var, self.use_bias, use_fp8=True,
             )
             fc1_weight_fp8, fc1_weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
@@ -9,7 +9,6 @@ from typing import Callable, Optional, Tuple, Union
 import os
 from keras import backend, layers, initializers
-from keras.mixed_precision import autocast_variable
 import tensorflow as tf
 from transformer_engine.tensorflow.module import (
@@ -770,9 +769,7 @@ class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
         bias_dropout_add_func = get_bias_dropout_add(training)
         # Bias dropout add.
-        # The autocast scope is used to enforce the correct dtype for the bias.
-        with autocast_variable.enable_auto_cast_variables(
-                self._compute_dtype_object):
+        attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype)
         if self.drop_path is None:
             bda_output = bias_dropout_add_func(
                 attention_output,
@@ -809,11 +806,7 @@ class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
         attention_output, attention_bias = inter_attention_outputs
         residual = bda_output
-        # The autocast scope is used to enforce the correct dtype for the
-        # bias.
-        with autocast_variable.enable_auto_cast_variables(
-            self._compute_dtype_object
-        ):
+        attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype)
         bda_output = bias_dropout_add_func(
             attention_output,
             attention_bias,
@@ -833,9 +826,7 @@ class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
         residual = bda_output
         # Bias dropout add.
-        # The autocast scope is used to enforce the correct dtype for the bias.
-        with autocast_variable.enable_auto_cast_variables(
-                self._compute_dtype_object):
+        mlp_bias = tf.cast(mlp_bias, dtype=self.compute_dtype)
         if self.drop_path is None:
             output = bias_dropout_add_func(
                 mlp_output,
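
The TransformerLayer hunks apply the same idea: instead of wrapping bias_dropout_add_func in an autocast scope, each bias is cast to the layer's compute dtype up front. A minimal sketch, assuming a mixed_float16-style setup; bias_dropout_add below is a simplified stand-in for the real get_bias_dropout_add(training), not the package's implementation:

```python
import tensorflow as tf

compute_dtype = tf.float16  # e.g. under a mixed_float16 policy

def bias_dropout_add(x, bias, residual, rate, training):
    # Simplified stand-in for the fused get_bias_dropout_add(training).
    out = x + bias
    if training:
        out = tf.nn.dropout(out, rate=rate)
    return out + residual

mlp_output = tf.random.normal([2, 4], dtype=compute_dtype)
residual = tf.random.normal([2, 4], dtype=compute_dtype)
mlp_bias = tf.Variable(tf.zeros([4]))  # stored in float32

# Explicit cast replaces the removed autocast scope.
mlp_bias_cast = tf.cast(mlp_bias, dtype=compute_dtype)
output = bias_dropout_add(mlp_output, mlp_bias_cast, residual,
                          rate=0.1, training=True)
print(output.dtype)  # float16
```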