Unverified commit b7acb6e1, authored by Trevor Morris and committed by GitHub

Add TensorFlow module and extensions (#85)



* Add tensorflow build

Improve build instructions

Fix pybind enum usage

Fix Python_EXECUTABLE cmake var

Move scale_inv calculations to FW
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Apply clang-format
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Format python files
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add TF build CI
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Lint checks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Another round of lint checks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix TF image tag
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>

* Use the existing recipe file
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add license claim blocks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix a bug about bias dtype conversion
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add mnist example and cleanup old examples
Signed-off-by: kaixih <kaixih@nvidia.com>

* Autopep8 the tests
Signed-off-by: kaixih <kaixih@nvidia.com>

* Autopep8 the examples
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add example in Readme
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add unit tests and linting for TensorFlow
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add causal mask for non-fused case
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix the mismatched TF vs TE masks
Signed-off-by: kaixih <kaixih@nvidia.com>

* Addressing CI tests
Signed-off-by: kaixih <kaixih@nvidia.com>

* Run lint test
Signed-off-by: kaixih <kaixih@nvidia.com>

* Add missing import
Signed-off-by: kaixih <kaixih@nvidia.com>

* Skip fp8 tests for pre-Hopper GPUs
Signed-off-by: kaixih <kaixih@nvidia.com>

* Remove non-pytest tests
Signed-off-by: kaixih <kaixih@nvidia.com>

* Fix license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: kaixih <kaixih@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 0963b288
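Before the files themselves, a minimal sketch of how the new TensorFlow extensions are meant to be driven, based on the `fp8_autocast` context manager and the recipe classes that appear in this diff. The `transformer_engine.tensorflow` import path and the Keras stand-in model are assumptions for illustration, not part of the diff:

import tensorflow as tf
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.tensorflow import fp8_autocast  # assumed import path

model = tf.keras.layers.Dense(64)   # stand-in for a Transformer Engine layer
inp = tf.random.uniform([16, 64])   # both dims divisible by 16, as FP8 requires

recipe = DelayedScaling(margin=0, interval=1, fp8_format=Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)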
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <pybind11/pybind11.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>
#include "common/include/transformer_engine/activation.h"
#include "common/include/transformer_engine/cast.h"
#include "common/include/transformer_engine/gemm.h"
#include "common/include/transformer_engine/layer_norm.h"
#include "common/include/transformer_engine/softmax.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "common/include/transformer_engine/transpose.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
namespace py = pybind11;
namespace transformer_engine {
// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8FwdTensors {
GEMM1_INPUT = 0,
GEMM1_WEIGHT = 1,
GEMM2_INPUT = 2,
GEMM2_WEIGHT = 3
};
// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8BwdTensors { GRAD_OUTPUT1 = 0, GRAD_OUTPUT2 = 1 };
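// Example (illustrative): on the Python side these enums are passed as the
// `offset` argument of the bound kernels, e.g.
//   tex.cast_to_fp8(x, scale, otype, amax, scale_inv,
//                   tex.FP8FwdTensors.GEMM1_INPUT, stream_id)
// selects element 0 of the packed amax/scale/scale_inv meta tensors.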
} // namespace transformer_engine
namespace {
void CheckTensorIsOnGPU(TFE_TensorHandle* tensor, TF_Status* status) {
const char* device_type = TFE_TensorHandleDeviceType(tensor, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(std::string(device_type), std::string("GPU"))
<< "Tensor must be on the GPU, but got device_type=" << device_type;
}
std::vector<size_t> TensorShapeAsVector(TFE_TensorHandle* tensor,
TF_Status* status) {
std::vector<size_t> shape(TFE_TensorHandleNumDims(tensor, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(tensor, i, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
return shape;
}
transformer_engine::DType GetNVTEDataType(TF_DataType t) {
switch (t) {
case TF_HALF:
return transformer_engine::DType::kFloat16;
case TF_FLOAT:
return transformer_engine::DType::kFloat32;
case TF_BFLOAT16:
return transformer_engine::DType::kBFloat16;
case TF_BOOL:
case TF_INT8:
return transformer_engine::DType::kByte;
case TF_INT32:
return transformer_engine::DType::kInt32;
default:
CHECK(false) << "TF dtype is not supported: " << t;
}
}
TF_DataType GetTFDataType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kFloat16:
return TF_HALF;
case transformer_engine::DType::kFloat32:
return TF_FLOAT;
case transformer_engine::DType::kBFloat16:
return TF_BFLOAT16;
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
return TF_INT8;
case transformer_engine::DType::kInt32:
return TF_INT32;
default:
CHECK(false) << "NVTE dtype is not supported: " << static_cast<int>(t);
}
}
void* TFE_TensorHandleDevicePointerNoSync(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
tensorflow::unwrap(h);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
unwrapped_handle)
->DevicePointer();
}
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(unwrapped_handle);
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDevicePointerNoSync may not be called on a ",
handle->TypeString(), " tensor handle.");
return nullptr;
}
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
return const_cast<void*>(
static_cast<const void*>(tensor->tensor_data().data()));
}
// We assume the dptr is float when applying the offset. The offset is only
// meaningful for the amax/scale/scale_inv tensors.
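// For example, offset = FP8FwdTensors::GEMM2_INPUT (i.e. 2) advances the
// returned float* by two elements into the packed meta tensor, while
// offset == -1 means "no tensor" and yields nullptr.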
void* GetDevicePtr(const pybind11::handle& handle, int offset = 0) {
if (offset == -1) return nullptr;
CHECK(EagerTensor_CheckExact(handle.ptr())) << "EagerTensor required!";
auto in_eager = EagerTensor_Handle(handle.ptr());
auto status = TF_NewStatus();
CheckTensorIsOnGPU(in_eager, status);
void* in_dptr = nullptr;
if (in_eager) {
in_dptr = TFE_TensorHandleDevicePointerNoSync(in_eager, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TF_DeleteStatus(status);
return reinterpret_cast<float*>(in_dptr) + offset;
}
std::vector<size_t> GetShape(const pybind11::handle& handle) {
TFE_TensorHandle* in_eager = EagerTensor_Handle(handle.ptr());
TF_Status* status = TF_NewStatus();
std::vector<size_t> shape = TensorShapeAsVector(in_eager, status);
TF_DeleteStatus(status);
return shape;
}
transformer_engine::DType GetDataType(const pybind11::handle& handle) {
TFE_TensorHandle* in_eager = EagerTensor_Handle(handle.ptr());
auto tf_itype = TFE_TensorHandleDataType(in_eager);
return GetNVTEDataType(tf_itype);
}
transformer_engine::TensorWrapper MakeNVTETensor(
void* data_ptr, const std::vector<size_t>& shape,
const transformer_engine::DType type, void* amax_ptr = nullptr,
void* scale_ptr = nullptr, void* scale_inv_ptr = nullptr) {
return transformer_engine::TensorWrapper(
data_ptr, shape, type, reinterpret_cast<float*>(amax_ptr),
reinterpret_cast<float*>(scale_ptr),
reinterpret_cast<float*>(scale_inv_ptr));
}
tensorflow::Allocator* GetAllocator() {
static tensorflow::Allocator* allocator = nullptr;
if (allocator == nullptr) {
tensorflow::GPUOptions gpu_options;
tsl::TfDeviceId device_id(0);
allocator = tensorflow::GPUProcessState::singleton()->GetGPUAllocator(
gpu_options, device_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
}
return allocator;
}
TFE_Context* GetContext(TF_Status* status) {
// Cache TF context.
static TFE_Context* context = nullptr;
if (context == nullptr) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
context = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
return context;
}
void Deallocator(void* data, size_t unused, void* tensor_handle) {
GetAllocator()->DeallocateRaw(data);
}
void* AllocateSpace(const std::vector<size_t>& shape,
transformer_engine::DType te_dtype, cudaStream_t stream = 0,
bool init_to_zeros = false) {
auto dtype = GetTFDataType(te_dtype);
// Allocate GPU memory.
size_t num_bytes = TF_DataTypeSize(dtype);
for (int i = 0; i < shape.size(); ++i) num_bytes *= shape[i];
void* data = GetAllocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, num_bytes);
if (init_to_zeros) {
CHECK_EQ(cudaMemsetAsync(data, 0, num_bytes, stream), cudaSuccess);
}
return data;
}
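// Note: this memory comes from TensorFlow's own GPU allocator (see
// GetAllocator() above), so it is accounted for like any other TF
// allocation; tensors created over it by CreateTensor() below hand it back
// through Deallocator() when they are destroyed.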
TFE_TensorHandle* CreateTensor(void* data, const std::vector<size_t>& shape,
transformer_engine::DType te_dtype) {
auto dtype = GetTFDataType(te_dtype);
size_t num_bytes = TF_DataTypeSize(dtype);
for (int i = 0; i < shape.size(); ++i) num_bytes *= shape[i];
TF_Status* status = TF_NewStatus();
TFE_Context* ctx = GetContext(status);
// Get first GPU device name.
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_devices = TF_DeviceListCount(devices);
const char* device_name = nullptr;
for (int i = 0; i < num_devices; ++i) {
const char* name = TF_DeviceListName(devices, i, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (std::string(name).find("GPU") != std::string::npos) {
device_name = name;
break;
}
}
CHECK_NE(device_name, nullptr) << "No GPU device found.";
std::vector<int64_t> shape64(shape.size());
std::transform(shape.cbegin(), shape.cend(), shape64.begin(),
[](const size_t& v) { return static_cast<int64_t>(v); });
TFE_TensorHandle* tensor = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name, dtype, shape64.data(), shape64.size(), data, num_bytes,
&Deallocator, nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
return tensor;
}
void dispatch_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type, cudaStream_t stream) {
auto input_cu = MakeNVTETensor(input, input_shape, input_type);
auto output_cast_cu = MakeNVTETensor(
output_cast, output_cast_shape, output_cast_type, amax, scale, scale_inv);
auto output_transpose_cu =
MakeNVTETensor(output_transpose, output_transpose_shape,
output_transpose_type, amax, scale, scale_inv);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(),
output_transpose_cu.data(), stream);
}
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
cudaStream_t stream) {
auto input_cu = MakeNVTETensor(input, input_shape, input_type);
auto output_cu = MakeNVTETensor(output, output_shape, output_type);
nvte_transpose(input_cu.data(), output_cu.data(), stream);
}
void dispatch_bgrad_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type, cudaStream_t stream) {
auto input_cu = MakeNVTETensor(input, input_shape, input_type);
auto cast_output_cu = MakeNVTETensor(
cast_output, cast_output_shape, cast_output_type, amax, scale, scale_inv);
auto transposed_output_cu =
MakeNVTETensor(transposed_output, transposed_output_shape,
transposed_output_type, amax, scale, scale_inv);
auto dbias_cu = MakeNVTETensor(dbias, dbias_shape, dbias_type);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), stream);
// Fill workspace
auto w_s = workspace.shape();
std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), stream);
}
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount, cudaStream_t stream) {
auto input_cu = MakeNVTETensor(input, input_shape, input_type);
auto gamma_cu = MakeNVTETensor(gamma, gamma_shape, gamma_type);
auto beta_cu = MakeNVTETensor(beta, beta_shape, beta_type);
auto z_cu = MakeNVTETensor(z, z_shape, z_type, amax, scale, scale_inv);
auto mu_cu = MakeNVTETensor(mu, mu_shape, mu_type);
auto rsigma_cu = MakeNVTETensor(rsigma, rsigma_shape, rsigma_type);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), epsilon,
z_cu.data(), mu_cu.data(), rsigma_cu.data(), stream,
multiProcessorCount, workspace.data(), barrier.data());
// Fill workspace and barrier
auto w_s = workspace.shape();
auto b_s = barrier.shape();
std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
std::vector<size_t> b_shape_vec{b_s.data, b_s.data + b_s.ndim};
void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
void* barrier_ptr = AllocateSpace(b_shape_vec, barrier.dtype(), stream, true);
workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());
barrier = MakeNVTETensor(barrier_ptr, b_shape_vec, barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), epsilon,
z_cu.data(), mu_cu.data(), rsigma_cu.data(), stream,
multiProcessorCount, workspace.data(), barrier.data());
}
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
cudaStream_t stream) {
auto input_cu = MakeNVTETensor(input, input_shape, input_type);
auto output_cu =
MakeNVTETensor(output, output_shape, output_type, amax, scale, scale_inv);
nvte_gelu(input_cu.data(), output_cu.data(), stream);
}
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type, cudaStream_t stream) {
auto gelu_input_cu =
MakeNVTETensor(gelu_input, gelu_input_shape, gelu_input_type);
auto input_cu = MakeNVTETensor(input, input_shape, input_type);
auto cast_output_cu = MakeNVTETensor(
cast_output, cast_output_shape, cast_output_type, amax, scale, scale_inv);
auto transposed_output_cu =
MakeNVTETensor(transposed_output, transposed_output_shape,
transposed_output_type, amax, scale, scale_inv);
auto dbias_cu = MakeNVTETensor(dbias, dbias_shape, dbias_type);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias_dgelu(
input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), workspace.data(), stream);
// Fill workspace
auto w_s = workspace.shape();
std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());
nvte_cast_transpose_dbias_dgelu(
input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), workspace.data(), stream);
}
TFE_TensorHandle* GetTFETensorHandle(const pybind11::handle tensor) {
CHECK(EagerTensor_CheckExact(tensor.ptr()))
<< "All inputs must be EagerTensors.";
return EagerTensor_Handle(tensor.ptr());
}
int GetDeviceMultiProcessorCount() {
static int count = [] {
cudaDeviceProp properties;
// Get current device
int device = -1;
CHECK_EQ(cudaGetDevice(&device), cudaSuccess)
<< "Got invalid GPU" << device;
CHECK_EQ(cudaGetDeviceProperties(&properties, device), cudaSuccess);
return properties.multiProcessorCount;
}();
return count;
}
py::object TFE_Py_TeGemm_wrapper(
const pybind11::handle& a_mat, const pybind11::handle& a_scale_inv,
const transformer_engine::DType atype, const int a_offset,
const pybind11::handle& b_mat, const pybind11::handle& b_scale_inv,
const transformer_engine::DType btype, const int b_offset,
const pybind11::handle& workspace, const bool use_bias,
const pybind11::handle& bias, const bool use_gelu,
const pybind11::handle& gelu_input, const bool transa,
const bool transb, const bool grad, const bool accumulate,
const bool use_split_accumulate, const transformer_engine::DType otype,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> a_shape = GetShape(a_mat);
std::vector<size_t> b_shape = GetShape(b_mat);
CHECK_EQ(a_shape.size(), 2);
CHECK_EQ(b_shape.size(), 2);
std::vector<size_t> d_shape{transb ? b_shape[1] : b_shape[0],
transa ? a_shape[0] : a_shape[1]};
auto a_tensor =
MakeNVTETensor(GetDevicePtr(a_mat), a_shape, atype, nullptr,
nullptr, GetDevicePtr(a_scale_inv, a_offset));
auto b_tensor =
MakeNVTETensor(GetDevicePtr(b_mat), b_shape, btype, nullptr,
nullptr, GetDevicePtr(b_scale_inv, b_offset));
NVTEShape empty_shape;
TensorWrapper bias_tensor(nullptr, empty_shape, DType::kBFloat16);
if (use_bias) {
bias_tensor = MakeNVTETensor(GetDevicePtr(bias), GetShape(bias),
GetDataType(bias));
}
TensorWrapper gelu_input_tensor(nullptr, empty_shape, DType::kBFloat16);
void* gelu_input_ptr = nullptr;
if (use_gelu && !grad) {
gelu_input_ptr = AllocateSpace(d_shape, otype);
gelu_input_tensor = MakeNVTETensor(gelu_input_ptr, d_shape, otype);
} else if (use_gelu) {
gelu_input_tensor =
MakeNVTETensor(GetDevicePtr(gelu_input), GetShape(gelu_input),
GetDataType(gelu_input));
}
auto workspace_tensor =
MakeNVTETensor(GetDevicePtr(workspace), GetShape(workspace),
GetDataType(workspace));
void* d_ptr = AllocateSpace(d_shape, otype);
auto d_tensor = MakeNVTETensor(d_ptr, d_shape, otype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_cublas_gemm(a_tensor.data(), b_tensor.data(), d_tensor.data(),
bias_tensor.data(), gelu_input_tensor.data(), transa,
transb, grad, workspace_tensor.data(), accumulate,
use_split_accumulate, stream);
auto d_eager = CreateTensor(d_ptr, d_shape, otype);
if (use_gelu && !grad) {
auto gelu_input_eager = CreateTensor(gelu_input_ptr, d_shape, otype);
PyObject* result(PyList_New(2));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(d_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(gelu_input_eager));
return tensorflow::PyoOrThrow(result);
}
return tensorflow::PyoOrThrow(EagerTensorFromHandle(d_eager));
}
} // end namespace
PYBIND11_MODULE(transformer_engine_tensorflow, m) {
py::enum_<transformer_engine::DType>(m, "DType")
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
.value("kFloat16", transformer_engine::DType::kFloat16)
.value("kBFloat16", transformer_engine::DType::kBFloat16)
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors",
py::arithmetic())
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors",
py::arithmetic())
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2);
m.def("cast_to_fp8",
[](const pybind11::handle& input, const pybind11::handle& scale,
const transformer_engine::DType otype, const pybind11::handle& amax,
const pybind11::handle& scale_inv, const int offset,
const int64_t stream_id) {
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
auto input_tensor =
MakeNVTETensor(GetDevicePtr(input), shape_c, GetDataType(input));
void* out_c_ptr = AllocateSpace(shape_c, otype);
auto output_tensor = MakeNVTETensor(
out_c_ptr, shape_c, otype, GetDevicePtr(amax, offset),
GetDevicePtr(scale, offset), GetDevicePtr(scale_inv, offset));
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
auto out_c_eager = CreateTensor(out_c_ptr, shape_c, otype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_c_eager));
});
m.def("cast_from_fp8", [](const pybind11::handle& input,
const pybind11::handle& scale_inv,
const transformer_engine::DType itype,
const transformer_engine::DType otype,
const int offset, const int64_t stream_id) {
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
auto input_tensor =
MakeNVTETensor(GetDevicePtr(input), shape_c, itype, nullptr, nullptr,
GetDevicePtr(scale_inv, offset));
void* out_ptr = AllocateSpace(shape_c, otype);
auto output_tensor = MakeNVTETensor(out_ptr, shape_c, otype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
auto out_eager = CreateTensor(out_ptr, shape_c, otype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_eager));
});
m.def("fp8_cast_transpose_fused",
[](const pybind11::handle& input, const pybind11::handle& scale,
const transformer_engine::DType otype, const pybind11::handle& amax,
const pybind11::handle& scale_inv, const int offset,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
void* out_c_ptr = AllocateSpace(shape_c, otype);
void* out_t_ptr = AllocateSpace(shape_t, otype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_cast_transpose_fusion(
GetDevicePtr(input), shape_c, GetDataType(input),
GetDevicePtr(scale, offset), {1}, DType::kFloat32, out_c_ptr,
shape_c, otype, out_t_ptr, shape_t, otype,
GetDevicePtr(amax, offset), {1}, DType::kFloat32,
GetDevicePtr(scale_inv, offset), {1}, DType::kFloat32, stream);
auto out_c_eager = CreateTensor(out_c_ptr, shape_c, otype);
auto out_t_eager = CreateTensor(out_t_ptr, shape_t, otype);
PyObject* result(PyList_New(2));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(out_c_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(out_t_eager));
return tensorflow::PyoOrThrow(result);
});
m.def("fp8_transpose", [](const pybind11::handle& input,
transformer_engine::DType otype,
const int64_t stream_id) {
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
void* out_t_ptr = AllocateSpace(shape_t, otype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_transpose(GetDevicePtr(input), shape_c, otype, out_t_ptr, shape_t,
otype, stream);
TFE_TensorHandle* out_t_eager = CreateTensor(out_t_ptr, shape_t, otype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_t_eager));
});
m.def("fp8_cast_transpose_bgrad_fused",
[](const pybind11::handle& grad_out, const pybind11::handle& scale,
const transformer_engine::DType otype, const pybind11::handle& amax,
const pybind11::handle& scale_inv, const int offset,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_c = GetShape(grad_out);
CHECK_EQ(shape_c.size(), 2);
std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
std::vector<size_t> shape_b{shape_c[1]};
auto itype = GetDataType(grad_out);
void* grad_bias_ptr = AllocateSpace(shape_b, itype);
void* grad_out_c_ptr = AllocateSpace(shape_c, otype);
void* grad_out_t_ptr = AllocateSpace(shape_t, otype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_bgrad_cast_transpose_fusion(
GetDevicePtr(grad_out), shape_c, itype,
GetDevicePtr(scale, offset), {1}, DType::kFloat32, grad_out_c_ptr,
shape_c, otype, grad_out_t_ptr, shape_t, otype,
GetDevicePtr(amax, offset), {1}, DType::kFloat32, grad_bias_ptr,
shape_b, itype, GetDevicePtr(scale_inv, offset), {1},
DType::kFloat32, stream);
auto grad_bias_eager = CreateTensor(grad_bias_ptr, shape_b, itype);
auto grad_out_c_eager = CreateTensor(grad_out_c_ptr, shape_c, otype);
auto grad_out_t_eager = CreateTensor(grad_out_t_ptr, shape_t, otype);
PyObject* result(PyList_New(3));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(grad_bias_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(grad_out_c_eager));
PyList_SET_ITEM(result, 2, EagerTensorFromHandle(grad_out_t_eager));
return tensorflow::PyoOrThrow(result);
});
m.def(
"te_gemm",
[](const pybind11::handle& a_mat, const pybind11::handle& a_scale_inv,
const transformer_engine::DType atype, const int a_offset,
const pybind11::handle& b_mat, const pybind11::handle& b_scale_inv,
const transformer_engine::DType btype, const int b_offset,
const pybind11::handle& workspace, const bool use_bias,
const pybind11::handle& bias, const bool use_gelu,
const pybind11::handle& gelu_input, const bool transa,
const bool transb, const bool grad, const bool accumulate,
const bool use_split_accumulate, const transformer_engine::DType otype,
const int64_t stream_id) {
return TFE_Py_TeGemm_wrapper(a_mat, a_scale_inv, atype, a_offset, b_mat,
b_scale_inv, btype, b_offset, workspace,
use_bias, bias, use_gelu, gelu_input, transa,
transb, grad, accumulate,
use_split_accumulate, otype, stream_id);
});
m.def("layernorm_fwd",
[](const pybind11::handle& input, const pybind11::handle& gamma,
const pybind11::handle& beta, float eps, const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
std::vector<size_t> shape_g{shape_c[1]};
std::vector<size_t> shape_m{shape_c[0]};
auto itype = GetDataType(input);
auto mtype = DType::kFloat32;
void* ln_out_ptr = AllocateSpace(shape_c, itype);
void* mu_ptr = AllocateSpace(shape_m, mtype);
void* rsigma_ptr = AllocateSpace(shape_m, mtype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_layernorm(
GetDevicePtr(input), shape_c, itype, GetDevicePtr(gamma), shape_g,
itype, GetDevicePtr(beta), shape_g, itype, nullptr, {1}, mtype,
eps, ln_out_ptr, shape_c, itype, mu_ptr, shape_m, mtype,
rsigma_ptr, shape_m, mtype, nullptr, {1}, mtype, nullptr, {1},
mtype, GetDeviceMultiProcessorCount(), stream);
auto ln_out_eager = CreateTensor(ln_out_ptr, shape_c, itype);
auto mu_eager = CreateTensor(mu_ptr, shape_m, mtype);
auto rsigma_eager = CreateTensor(rsigma_ptr, shape_m, mtype);
PyObject* result(PyList_New(3));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(ln_out_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(mu_eager));
PyList_SET_ITEM(result, 2, EagerTensorFromHandle(rsigma_eager));
return tensorflow::PyoOrThrow(result);
});
m.def("layernorm_fwd_fp8",
[](const pybind11::handle& input, const pybind11::handle& gamma,
const pybind11::handle& beta, float eps,
const pybind11::handle& scale, const transformer_engine::DType otype,
const pybind11::handle& amax, const pybind11::handle& scale_inv,
const int offset, const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
std::vector<size_t> shape_g{shape_c[1]};
std::vector<size_t> shape_m{shape_c[0]};
auto itype = GetDataType(input);
auto mtype = DType::kFloat32;
void* ln_out_ptr = AllocateSpace(shape_c, otype);
void* mu_ptr = AllocateSpace(shape_m, mtype);
void* rsigma_ptr = AllocateSpace(shape_m, mtype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_layernorm(
GetDevicePtr(input), shape_c, itype, GetDevicePtr(gamma), shape_g,
itype, GetDevicePtr(beta), shape_g, itype,
GetDevicePtr(scale, offset), {1}, DType::kFloat32, eps,
ln_out_ptr, shape_c, otype, mu_ptr, shape_m, mtype, rsigma_ptr,
shape_m, mtype, GetDevicePtr(amax, offset), {1}, DType::kFloat32,
GetDevicePtr(scale_inv, offset), {1}, DType::kFloat32,
GetDeviceMultiProcessorCount(), stream);
auto ln_out_eager = CreateTensor(ln_out_ptr, shape_c, otype);
auto mu_eager = CreateTensor(mu_ptr, shape_m, mtype);
auto rsigma_eager = CreateTensor(rsigma_ptr, shape_m, mtype);
PyObject* result(PyList_New(3));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(ln_out_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(mu_eager));
PyList_SET_ITEM(result, 2, EagerTensorFromHandle(rsigma_eager));
return tensorflow::PyoOrThrow(result);
});
m.def("layernorm_bwd", [](const pybind11::handle& dz,
const pybind11::handle& x,
const pybind11::handle& mu,
const pybind11::handle& rsigma,
const pybind11::handle& gamma,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_x = GetShape(x);
CHECK_EQ(shape_x.size(), 2);
std::vector<size_t> shape_g{shape_x[1]};
std::vector<size_t> shape_m{shape_x[0]};
auto xtype = GetDataType(x);
auto gtype = GetDataType(gamma);
auto mtype = GetDataType(mu);
void* dx_ptr = AllocateSpace(shape_x, xtype);
void* dgamma_ptr = AllocateSpace(shape_g, gtype);
void* dbeta_ptr = AllocateSpace(shape_g, gtype);
auto x_tensor = MakeNVTETensor(GetDevicePtr(x), shape_x, xtype);
auto gamma_tensor = MakeNVTETensor(GetDevicePtr(gamma), shape_g, gtype);
auto dz_tensor = MakeNVTETensor(GetDevicePtr(dz), shape_x, xtype);
auto mu_tensor = MakeNVTETensor(GetDevicePtr(mu), shape_m, mtype);
auto rsigma_tensor = MakeNVTETensor(GetDevicePtr(rsigma), shape_m, mtype);
auto dx_tensor = MakeNVTETensor(dx_ptr, shape_x, xtype);
auto dgamma_tensor = MakeNVTETensor(dgamma_ptr, shape_g, gtype);
auto dbeta_tensor = MakeNVTETensor(dbeta_ptr, shape_g, gtype);
TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
// This call populates tensors with the required config.
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_layernorm_bwd(
dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), dx_tensor.data(),
dgamma_tensor.data(), dbeta_tensor.data(), dgamma_part.data(),
dbeta_part.data(), stream, GetDeviceMultiProcessorCount(),
workspace.data(), barrier.data());
// Alloc space for Tensors.
auto w_s = workspace.shape();
auto b_s = barrier.shape();
auto dg_s = dgamma_part.shape();
auto db_s = dbeta_part.shape();
std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
std::vector<size_t> b_shape_vec{b_s.data, b_s.data + b_s.ndim};
std::vector<size_t> dg_shape_vec{dg_s.data, dg_s.data + dg_s.ndim};
std::vector<size_t> db_shape_vec{db_s.data, db_s.data + db_s.ndim};
void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
void* barrier_ptr =
AllocateSpace(b_shape_vec, barrier.dtype(), stream, true);
void* dgamma_part_ptr = AllocateSpace(dg_shape_vec, dgamma_part.dtype());
void* dbeta_part_ptr = AllocateSpace(db_shape_vec, dbeta_part.dtype());
workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());
barrier = MakeNVTETensor(barrier_ptr, b_shape_vec, barrier.dtype());
dgamma_part =
MakeNVTETensor(dgamma_part_ptr, dg_shape_vec, dgamma_part.dtype());
dbeta_part =
MakeNVTETensor(dbeta_part_ptr, db_shape_vec, dbeta_part.dtype());
// Actual call to bwd kernel.
nvte_layernorm_bwd(
dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), dx_tensor.data(),
dgamma_tensor.data(), dbeta_tensor.data(), dgamma_part.data(),
dbeta_part.data(), stream, GetDeviceMultiProcessorCount(),
workspace.data(), barrier.data());
auto dx_eager = CreateTensor(dx_ptr, shape_x, xtype);
auto dgamma_eager = CreateTensor(dgamma_ptr, shape_g, gtype);
auto dbeta_eager = CreateTensor(dbeta_ptr, shape_g, gtype);
PyObject* result(PyList_New(3));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(dx_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(dgamma_eager));
PyList_SET_ITEM(result, 2, EagerTensorFromHandle(dbeta_eager));
return tensorflow::PyoOrThrow(result);
});
m.def("te_gelu",
[](const pybind11::handle& input, const pybind11::handle& scale,
const transformer_engine::DType otype, const pybind11::handle& amax,
const pybind11::handle& scale_inv, const int offset,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_c = GetShape(input);
CHECK_EQ(shape_c.size(), 2);
void* out_ptr = AllocateSpace(shape_c, otype);
auto itype = GetDataType(input);
void* scale_ptr = nullptr;
void* amax_ptr = nullptr;
void* scale_inv_ptr = nullptr;
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
scale_ptr = GetDevicePtr(scale, offset);
amax_ptr = GetDevicePtr(amax, offset);
scale_inv_ptr = GetDevicePtr(scale_inv, offset);
}
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_gelu(GetDevicePtr(input), shape_c, itype, scale_ptr, {1},
DType::kFloat32, out_ptr, shape_c, otype, amax_ptr, {1},
DType::kFloat32, scale_inv_ptr, {1}, DType::kFloat32,
stream);
auto out_eager = CreateTensor(out_ptr, shape_c, otype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_eager));
});
m.def("fp8_fused_cast_transpose_bgrad_dgelu",
[](const pybind11::handle& grad_output,
const pybind11::handle& gelu_input, const pybind11::handle& scale,
const transformer_engine::DType otype, const pybind11::handle& amax,
const pybind11::handle& scale_inv, const int offset,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_c = GetShape(grad_output);
CHECK_EQ(shape_c.size(), 2);
std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
std::vector<size_t> shape_b{shape_c[1]};
auto itype = GetDataType(grad_output);
void* grad_bias_ptr = AllocateSpace(shape_b, itype);
void* dgelu_c_ptr = AllocateSpace(shape_c, otype);
void* dgelu_t_ptr = AllocateSpace(shape_t, otype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
dispatch_bgrad_dgelu_cast_transpose_fusion(
GetDevicePtr(grad_output), shape_c, itype,
GetDevicePtr(gelu_input), shape_c, itype,
GetDevicePtr(scale, offset), {1}, DType::kFloat32, dgelu_c_ptr,
shape_c, otype, dgelu_t_ptr, shape_t, otype,
GetDevicePtr(amax, offset), {1}, DType::kFloat32, grad_bias_ptr,
shape_b, itype, GetDevicePtr(scale_inv, offset), {1},
DType::kFloat32, stream);
auto grad_bias_eager = CreateTensor(grad_bias_ptr, shape_b, itype);
auto dgelu_c_eager = CreateTensor(dgelu_c_ptr, shape_c, otype);
auto dgelu_t_eager = CreateTensor(dgelu_t_ptr, shape_t, otype);
PyObject* result(PyList_New(3));
PyList_SET_ITEM(result, 0, EagerTensorFromHandle(grad_bias_eager));
PyList_SET_ITEM(result, 1, EagerTensorFromHandle(dgelu_c_eager));
PyList_SET_ITEM(result, 2, EagerTensorFromHandle(dgelu_t_eager));
return tensorflow::PyoOrThrow(result);
});
m.def(
"scaled_upper_triang_masked_softmax_forward",
[](const pybind11::handle& input, const float scale_factor,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_in = GetShape(input);
CHECK_EQ(shape_in.size(), 3);
auto itype = GetDataType(input);
CHECK(itype == DType::kFloat16 || itype == DType::kBFloat16);
const size_t attn_batches = shape_in[0];
const size_t seq_len = shape_in[1];
CHECK_LE(seq_len, 2048);
auto input_cu = MakeNVTETensor(GetDevicePtr(input), shape_in, itype);
void* softmax_ptr = AllocateSpace(shape_in, itype);
auto softmax_results_cu = MakeNVTETensor(softmax_ptr, shape_in, itype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_scaled_upper_triang_masked_softmax_forward(
input_cu.data(), softmax_results_cu.data(), scale_factor, stream);
auto softmax_results_eager = CreateTensor(softmax_ptr, shape_in, itype);
return tensorflow::PyoOrThrow(
EagerTensorFromHandle(softmax_results_eager));
});
m.def("scaled_upper_triang_masked_softmax_backward",
[](const pybind11::handle& dy, const pybind11::handle& y,
const float scale_factor, const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_dy = GetShape(dy);
std::vector<size_t> shape_y = GetShape(y);
CHECK_EQ(shape_dy.size(), 3);
CHECK_EQ(shape_y.size(), 3);
auto dytype = GetDataType(dy);
auto ytype = GetDataType(y);
CHECK(dytype == DType::kFloat16 || dytype == DType::kBFloat16);
CHECK(ytype == DType::kFloat16 || ytype == DType::kBFloat16);
CHECK_EQ(shape_dy[1], shape_dy[2]);
auto dy_cu = MakeNVTETensor(GetDevicePtr(dy), shape_dy, dytype);
auto y_cu = MakeNVTETensor(GetDevicePtr(y), shape_y, ytype);
void* dx_ptr = AllocateSpace(shape_dy, dytype);
auto dx_cu = MakeNVTETensor(dx_ptr, shape_dy, dytype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_scaled_upper_triang_masked_softmax_backward(
dy_cu.data(), y_cu.data(), dx_cu.data(), scale_factor, stream);
auto dx_eager = CreateTensor(dx_ptr, shape_dy, dytype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(dx_eager));
});
m.def("scaled_masked_softmax_forward", [](const pybind11::handle& x,
const pybind11::handle& mask,
const float scale_factor,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_x = GetShape(x);
std::vector<size_t> shape_m = GetShape(mask);
CHECK_EQ(shape_x.size(), 4) << "expected 4D tensor";
CHECK_EQ(shape_m.size(), 4) << "expected 4D tensor";
auto xtype = GetDataType(x);
auto mtype = GetDataType(mask);
CHECK(xtype == DType::kFloat16 || xtype == DType::kBFloat16)
<< "Only fp16 and bf16 are supported";
CHECK(mtype == DType::kByte) << "Only bool is supported for the mask";
const size_t batches = shape_x[0];
const size_t pad_batches = shape_m[0];
const size_t attn_heads = shape_x[1];
const size_t query_seq_len = shape_x[2];
const size_t key_seq_len = shape_x[3];
CHECK_LE(key_seq_len, 4096);
CHECK_GT(query_seq_len, 1);
CHECK(pad_batches == 1 || pad_batches == batches);
CHECK_EQ(shape_m[1], 1);
CHECK(shape_m[2] == query_seq_len);
CHECK(shape_m[3] == key_seq_len);
void* softmax_ptr = AllocateSpace(shape_x, xtype);
auto softmax_results_cu = MakeNVTETensor(softmax_ptr, shape_x, xtype);
auto input_cu = MakeNVTETensor(GetDevicePtr(x), shape_x, xtype);
auto mask_cu = MakeNVTETensor(GetDevicePtr(mask), shape_m, mtype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(),
softmax_results_cu.data(), scale_factor,
stream);
auto softmax_results_eager = CreateTensor(softmax_ptr, shape_x, xtype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(softmax_results_eager));
});
m.def("scaled_masked_softmax_backward",
[](const pybind11::handle& dy, const pybind11::handle& y,
const float scale_factor, const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_dy = GetShape(dy);
std::vector<size_t> shape_y = GetShape(y);
CHECK_EQ(shape_dy.size(), 4) << "expected 4D tensor";
CHECK_EQ(shape_y.size(), 4) << "expected 4D tensor";
auto dytype = GetDataType(dy);
auto ytype = GetDataType(y);
CHECK(dytype == DType::kFloat16 || dytype == DType::kBFloat16)
<< "Only fp16 and bf16 are supported";
CHECK(ytype == DType::kFloat16 || ytype == DType::kBFloat16)
<< "Only fp16 and bf16 are supported";
auto dy_cu = MakeNVTETensor(GetDevicePtr(dy), shape_dy, dytype);
auto y_cu = MakeNVTETensor(GetDevicePtr(y), shape_y, ytype);
void* dx_ptr = AllocateSpace(shape_dy, dytype);
auto dx_cu = MakeNVTETensor(dx_ptr, shape_dy, dytype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_scaled_masked_softmax_backward(
dy_cu.data(), y_cu.data(), dx_cu.data(), scale_factor, stream);
auto dx_eager = CreateTensor(dx_ptr, shape_dy, dytype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(dx_eager));
});
m.def("scaled_softmax_forward", [](const pybind11::handle& x,
const float scale_factor,
const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_x = GetShape(x);
CHECK_EQ(shape_x.size(), 4) << "expected 4D tensor";
auto xtype = GetDataType(x);
CHECK(xtype == DType::kFloat16 || xtype == DType::kBFloat16)
<< "Only fp16 and bf16 are supported";
const size_t batches = shape_x[0];
const size_t attn_heads = shape_x[1];
const size_t query_seq_len = shape_x[2];
const size_t key_seq_len = shape_x[3];
CHECK_LE(key_seq_len, 4096);
CHECK_GT(query_seq_len, 1);
void* softmax_ptr = AllocateSpace(shape_x, xtype);
auto softmax_results_cu = MakeNVTETensor(softmax_ptr, shape_x, xtype);
auto input_cu = MakeNVTETensor(GetDevicePtr(x), shape_x, xtype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(),
scale_factor, stream);
auto softmax_results_eager = CreateTensor(softmax_ptr, shape_x, xtype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(softmax_results_eager));
});
m.def("scaled_softmax_backward",
[](const pybind11::handle& dy, const pybind11::handle& y,
const float scale_factor, const int64_t stream_id) {
using namespace transformer_engine;
std::vector<size_t> shape_dy = GetShape(dy);
std::vector<size_t> shape_y = GetShape(y);
CHECK_EQ(shape_dy.size(), 4) << "expected 4D tensor";
CHECK_EQ(shape_y.size(), 4) << "expected 4D tensor";
auto dytype = GetDataType(dy);
auto ytype = GetDataType(y);
CHECK(dytype == DType::kFloat16 || dytype == DType::kBFloat16)
<< "Only fp16 and bf16 are supported";
CHECK(ytype == DType::kFloat16 || ytype == DType::kBFloat16)
<< "Only fp16 and bf16 are supported";
auto dy_cu = MakeNVTETensor(GetDevicePtr(dy), shape_dy, dytype);
auto y_cu = MakeNVTETensor(GetDevicePtr(y), shape_y, ytype);
void* dx_ptr = AllocateSpace(shape_dy, dytype);
auto dx_cu = MakeNVTETensor(dx_ptr, shape_dy, dytype);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
nvte_scaled_softmax_backward(dy_cu.data(), y_cu.data(), dx_cu.data(),
scale_factor, stream);
auto dx_eager = CreateTensor(dx_ptr, shape_dy, dytype);
return tensorflow::PyoOrThrow(EagerTensorFromHandle(dx_eager));
});
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
class GetStreamOp : public OpKernel {
public:
explicit GetStreamOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("stream_id", {1}, &output));
auto vec = output->vec<int64_t>();
se::Stream* stream = ctx->op_device_context()->stream();
auto gpu_stream = se::gpu::AsGpuStreamValue(stream);
vec(0) = static_cast<int64_t>(reinterpret_cast<uintptr_t>(gpu_stream));
}
};
REGISTER_OP("GetStream")
.Output("stream_id: int64")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP_NO_GRADIENT("GetStream");
REGISTER_KERNEL_BUILDER(
Name("GetStream").Device(DEVICE_GPU).HostMemory("stream_id"), GetStreamOp);
} // namespace tensorflow
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 utilies for TransformerEngine"""
from contextlib import contextmanager
from typing import Optional, Dict, Any
import tensorflow as tf
import transformer_engine_tensorflow as tex
from transformer_engine.common.recipe import DelayedScaling, Format
_FP8_ENABLED = False
_FP8_RECIPE = None
_FP8_DISTRIBUTED_GROUP = None
_IS_FIRST_FP8_MODULE = False
_FP8_AUTOCAST_COUNTER = 0
_FP8_CURRENT_CONTEXT_ID = 0
_FP8_AUTOCAST_DEPTH = 0
_global_fp8_buffer = {}
_amax_forward_global_reduce_func = lambda: None
_buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
if forward:
return "scaling_fwd"
return "scaling_bwd"
def get_autocast_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "autocast_id_fwd"
return "autocast_id_bwd"
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
def set_amax_buffer_key_deletion(
fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if get_autocast_key(forward=forward) not in fp8_meta:
return
global _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
_buffer_delete_key_fwd = get_amax_buffer_key(fp8_meta, forward=forward)
else:
_buffer_delete_key_bwd = get_amax_buffer_key(fp8_meta, forward=forward)
def get_default_fp8_recipe():
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
return DelayedScaling()
@contextmanager
def fp8_autocast(
enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
"""
Context manager for FP8 usage.
.. code-block:: python
with fp8_autocast(enabled=True):
out = model(inp)
.. note::
Support for FP8 in the Dense layer of Transformer Engine is currently
limited to tensors with shapes where both dimensions are divisible by 16.
In terms of the input to the full Transformer network, this typically
requires padding the sequence length to a multiple of 16.
Parameters
----------
enabled: bool, default = `False`
whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
"""
global _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd
fp8_state = (_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
try:
_FP8_ENABLED = enabled
_FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
if _FP8_AUTOCAST_DEPTH == 0:
_IS_FIRST_FP8_MODULE = True
_FP8_AUTOCAST_COUNTER += 1
_FP8_AUTOCAST_DEPTH += 1
yield
finally:
_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
_IS_FIRST_FP8_MODULE = False
_FP8_AUTOCAST_DEPTH -= 1
if _FP8_AUTOCAST_DEPTH == 0:
if callable(_amax_forward_global_reduce_func):
_amax_forward_global_reduce_func()
delete_key_from_amax_buffer(forward=True)
def get_fp8_context_id() -> int:
"""Returns an ID for the current FP8 context."""
return _FP8_CURRENT_CONTEXT_ID
def set_fp8_context_id(ctx_id: int) -> None:
"""Sets the current FP8 context."""
global _FP8_CURRENT_CONTEXT_ID
_FP8_CURRENT_CONTEXT_ID = ctx_id
def new_fp8_context_id() -> int:
"""Returns global autocast counter as a proxy to be used
as the autocast ID for FP8 modules.
"""
return _FP8_AUTOCAST_COUNTER
def is_fp8_enabled():
"""Is FP8 enabled"""
return _FP8_ENABLED
def is_first_fp8_module():
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
global _IS_FIRST_FP8_MODULE
tmp = _IS_FIRST_FP8_MODULE
_IS_FIRST_FP8_MODULE = False
return tmp
def get_fp8_recipe():
"""Return the fp8 recipe"""
return _FP8_RECIPE
def _default_sf_compute(amax, scale, fp8_max, margin):
"""Default function to convert amax to scaling factor."""
exp = tf.math.floor(tf.experimental.numpy.log2(fp8_max / amax)) - margin
sf = tf.math.round(tf.math.pow(2.0, tf.math.abs(exp)))
sf = tf.where(amax > 0.0, sf, scale)
sf = tf.where(tf.math.is_finite(amax), sf, scale)
sf = tf.where(exp < 0, 1.0 / sf, sf)
return sf
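# Worked example (assuming the E4M3 format maximum of 448): with
# amax = 2.0, scale = 1.0 and margin = 0, exp = floor(log2(448 / 2)) = 7,
# so sf = round(2.0**7) = 128 and amax * sf = 256 stays within range. A
# non-positive or non-finite amax keeps the previous scale instead.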
def _roll_and_zero_out(amax_history):
"""Update amax history and set next amax to zero."""
amax_history = tf.roll(amax_history, -1, 0)
zeros = tf.zeros(shape=amax_history[0].shape)
updated = tf.tensor_scatter_nd_update(amax_history, [[0]], [zeros])
return updated
@tf.function(jit_compile=True)
def _reduce_max_and_default_sf_compute(amax_history, scale, fp8_max, margin):
"""Get amax using max algorithm and compute scaling factor."""
amax = tf.reduce_max(amax_history, axis=0)
sf = _default_sf_compute(amax, scale, fp8_max, margin)
updated = _roll_and_zero_out(amax_history)
return updated, sf
@tf.function(jit_compile=True)
def _most_recent_and_default_sf_compute(amax_history, scale, fp8_max, margin):
"""Get amax using most-recent algorithm and compute scaling factor."""
amax = amax_history[0]
sf = _default_sf_compute(amax, scale, fp8_max, margin)
updated = _roll_and_zero_out(amax_history)
return updated, sf
def fused_amax_and_scale_update(
amax_history: tf.Variable,
scale: tf.Variable,
scale_inv: tf.Variable,
fp8_max: float,
margin: int,
amax_compute_algo: str,
):
"""Amax to scale conversion."""
if amax_compute_algo == "max":
updated, sf = _reduce_max_and_default_sf_compute(
amax_history, scale, fp8_max, margin
)
else:
assert amax_compute_algo == "most_recent"
updated, sf = _most_recent_and_default_sf_compute(
amax_history, scale, fp8_max, margin
)
amax_history.assign(updated)
scale.assign(sf)
scale_inv.assign(1.0 / sf)
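# Illustrative usage (a sketch; the shapes below are assumptions, not API):
#   amax_history = tf.Variable(tf.zeros([16, 4]))  # [history_len, num_tensors]
#   scale = tf.Variable(tf.ones([4]))
#   scale_inv = tf.Variable(tf.ones([4]))
#   fused_amax_and_scale_update(amax_history, scale, scale_inv,
#                               fp8_max=448.0, margin=0,
#                               amax_compute_algo="max")
# Afterwards `scale` holds the new factors, `scale_inv` their reciprocals,
# and the history is rolled with a zeroed slot for the next amax.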
def amax_and_scale_update(
fp8_meta: Dict[str, Any],
fwd_update: bool,
) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo
sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
fused_amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key]["amax_history"],
fp8_meta[fp8_meta_tensor_key]["scale"],
fp8_meta[fp8_meta_tensor_key]["scale_inv"],
fp8_meta[fp8_max_key],
fp8_meta["recipe"].margin,
fp8_meta["recipe"].amax_compute_algo,
)
else:
raise ValueError(
"We only support the fp8 recipe with 'max' or 'most_recent' "
"amax_compute_algo and default scaling_factor_compute_algo at this "
"moment."
)
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True):
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
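# Note that Format.HYBRID thus uses E4M3 for forward tensors and E5M2 (more
# exponent bits, hence wider dynamic range) for gradient tensors.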
def delete_key_from_amax_buffer(forward: bool = True) -> None:
"""Delete the key from global amax buffer."""
global _global_fp8_buffer, _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
if (
_buffer_delete_key_fwd is not None
and _buffer_delete_key_fwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_fwd]
else:
if (
_buffer_delete_key_bwd is not None
and _buffer_delete_key_bwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_bwd]
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""XLA functions and JIT utilities"""
from typing import Callable
import tensorflow as tf
@tf.function(jit_compile=True)
def _bgrad_dgelu_fused(grad_output, inp):
"""Bgrad-Dgelu fused"""
x = inp
tanh_out = tf.math.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)
dgelu = ff * grad_output
bgrad = tf.math.reduce_sum(dgelu, axis=0)
return bgrad, dgelu
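# Derivation sketch: with u = sqrt(2/pi) * (x + 0.044715 * x**3) and
# gelu(x) = 0.5 * x * (1 + tanh(u)), the chain rule gives
#   dgelu/dx = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)**2) * du/dx,
# with du/dx = 0.79788456 * (1 + 3 * 0.044715 * x**2)
#            = 0.79788456 + 0.1070322243 * x**2,
# which is the `ff` factor above. The bias gradient is the sum of `dgelu`
# over axis 0 because the bias is broadcast across the batch dimension.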
def bgrad_dgelu_fused(grad_output, inp):
"""Bgrad-Dgelu fused"""
return _bgrad_dgelu_fused(grad_output, inp)
def bias_dropout_add(
x: tf.Tensor,
bias: tf.Variable,
residual: tf.Tensor,
prob: float,
training: bool,
) -> tf.Tensor:
"""dropout(inp + bias) + residual"""
# TODO(kaixih): Use stateless_dropout and specify the seed mainly for
# debugging purpose. Should allow random seed.
out = (
tf.nn.experimental.stateless_dropout(
x + bias,
rate=prob,
seed=[1, 0],
)
if training
else x + bias
)
out = residual + out
return out
def get_bias_dropout_add(training: bool) -> Callable:
"""bias_dropout_add based on training or not"""
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
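# The separately jitted train/inference variants below bake `training` in as
# a Python constant, so each tf.function traces exactly one graph instead of
# retracing on a boolean argument (a common tf.function pattern).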
@tf.function(jit_compile=True)
def bias_dropout_add_fused_train_(
x: tf.Tensor,
bias: tf.Variable,
residual: tf.Tensor,
prob: float,
) -> tf.Tensor:
"""Jit fused bias_dropout_add for training"""
return bias_dropout_add(x, bias, residual, prob, True)
def bias_dropout_add_fused_train(
x: tf.Tensor,
bias: tf.Variable,
residual: tf.Tensor,
prob: float,
) -> tf.Tensor:
"""Jit fused bias_dropout_add for training"""
return bias_dropout_add_fused_train_(x, bias, residual, prob)
@tf.function(jit_compile=True)
def bias_dropout_add_fused_inference_(
x: tf.Tensor,
bias: tf.Variable,
residual: tf.Tensor,
prob: float,
) -> tf.Tensor:
"""Jit fused bias_dropout_add for inference"""
return bias_dropout_add(x, bias, residual, prob, False)
def bias_dropout_add_fused_inference(
x: tf.Tensor,
bias: tf.Variable,
residual: tf.Tensor,
prob: float,
) -> tf.Tensor:
"""Jit fused bias_dropout_add for inference"""
return bias_dropout_add_fused_inference_(x, bias, residual, prob)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level Transformer Engine PyTorch modules"""
from typing import Union, Callable
from keras import backend, layers, initializers
from keras.mixed_precision import autocast_variable
import tensorflow as tf
import transformer_engine_tensorflow as tex
from .constants import TE_DType
from .fp8 import (
is_fp8_enabled,
get_fp8_recipe,
get_default_fp8_recipe,
get_fp8_te_dtype,
is_first_fp8_module,
new_fp8_context_id,
get_fp8_context_id,
set_fp8_context_id,
amax_and_scale_update,
set_amax_buffer_key_deletion,
get_meta_tensor_key,
)
from .jit import (
bgrad_dgelu_fused,
)
stream_lib = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile(
tf.sysconfig.get_lib() + "/../lib_get_stream.so"
)
)
def get_stream_id():
"""Get stream index for GPU tasks."""
return stream_lib.get_stream().numpy()[0]
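# Whether the fp8 gemms request split (higher-precision) accumulation; these
# flags are passed through as use_split_accumulate below: disabled for fprop,
# enabled for dgrad/wgrad.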
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
def get_workspace():
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
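        # Allocate a 32 MiB scratch buffer, shared by all cublasLt gemm calls.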
_cublas_workspace = tf.zeros([33_554_432], dtype=tf.int8)
return _cublas_workspace
def get_autocast_bias(dtype, bias_var, use_bias, use_fp8):
"""Get casted bias for fp8 gemm."""
if not use_bias:
return None
with autocast_variable.enable_auto_cast_variables(dtype):
bias = bias_var.value()
if use_fp8 and bias.dtype == tf.float32:
bias = tf.cast(bias, dtype=tf.bfloat16)
return bias
def get_init_method(user_input, default_init_method):
"""Get initializer method for variables."""
if user_input is None:
return default_init_method
if callable(user_input):
return user_input
assert isinstance(user_input, str)
return initializers.get(user_input)
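# Illustrative usage: strings resolve through keras initializers, callables
# pass through unchanged, and None falls back to the supplied default:
#
#   default = initializers.get("zeros")
#   get_init_method(None, default)              # -> default
#   get_init_method("glorot_uniform", default)  # -> GlorotUniform instance
#   get_init_method(initializers.Ones(), default)  # -> passed through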
def cast_to_fp8_wrapper(x, fp8_meta, amax_index, fwd, output_dtype, stream_id):
"""Wrapper to call the tex.cast_to_fp8."""
scaling_key = get_meta_tensor_key(fwd)
scale = fp8_meta[scaling_key]["scale"].value()
amax = fp8_meta[scaling_key]["amax_history"].value()
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
x_fp8 = tex.cast_to_fp8(
x, scale, output_dtype, amax, scale_inv, amax_index, stream_id
)
return x_fp8
def cast_from_fp8_wrapper(x, fp8_meta, amax_index, fwd, idtype, odtype, sid):
"""Wrapper to call the tex.cast_from_fp8."""
scaling_key = "scaling_fwd" if fwd else "scaling_bwd"
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
x_fp8 = tex.cast_from_fp8(x, scale_inv, idtype, odtype, amax_index, sid)
return x_fp8
def fp8_cast_transpose_fused_wrapper(x, fp8_meta, amax_index, fwd, output_dtype,
sid):
"""Wrapper to call the tex.fp8_cast_transpose_fused."""
scaling_key = get_meta_tensor_key(fwd)
scale = fp8_meta[scaling_key]["scale"].value()
amax = fp8_meta[scaling_key]["amax_history"].value()
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
x_fp8, x_t_fp8 = tex.fp8_cast_transpose_fused(
x, scale, output_dtype, amax, scale_inv, amax_index, sid
)
return x_fp8, x_t_fp8
def fp8_cast_transpose_bgrad_fused_wrapper(
x, fp8_meta, amax_index, fwd, output_dtype, sid
):
"""Wrapper to call the tex.fp8_cast_transpose_bgrad_fused."""
scaling_key = get_meta_tensor_key(fwd)
scale = fp8_meta[scaling_key]["scale"].value()
amax = fp8_meta[scaling_key]["amax_history"].value()
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
grad_bias, grad_fp8, grad_t_fp8 = tex.fp8_cast_transpose_bgrad_fused(
x, scale, output_dtype, amax, scale_inv, amax_index, sid
)
return grad_bias, grad_fp8, grad_t_fp8
def fp8_cast_transpose_bgrad_dgelu_fused_wrapper(
dy, x, fp8_meta, amax_index, fwd, output_dtype, sid
):
"""Wrapper to call the tex.fp8_fused_cast_transpose_bgrad_dgelu."""
scaling_key = get_meta_tensor_key(fwd)
scale = fp8_meta[scaling_key]["scale"].value()
amax = fp8_meta[scaling_key]["amax_history"].value()
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
dbias, dgelu_c, dgelu_t = tex.fp8_fused_cast_transpose_bgrad_dgelu(
dy, x, scale, output_dtype, amax, scale_inv, amax_index, sid
)
return dbias, dgelu_c, dgelu_t
def fp8_gelu_wrapper(x, fp8_meta, amax_index, fwd, output_dtype, sid):
"""Wrapper to call the tex.te_gelu."""
scaling_key = get_meta_tensor_key(fwd)
scale = fp8_meta[scaling_key]["scale"].value()
amax = fp8_meta[scaling_key]["amax_history"].value()
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
y_fp8 = tex.te_gelu(x, scale, output_dtype, amax, scale_inv, amax_index,
sid)
return y_fp8
def matmul_wrapper(
inp,
weight,
mode,
output_dtype,
sid,
use_bias=False,
bias=None,
grad=False,
gelu=False,
gelu_input=None,
):
"""Wrapper to call the tex.te_gemm for the non-fp8 gemm."""
A = inp
B = weight
A_dtype, B_dtype = TE_DType[A.dtype], TE_DType[B.dtype]
    A_offset, B_offset = -1, -1
    if mode in ("fwd", "fc1_fwd", "fc2_fwd"):
        transA, transB = False, False
    elif mode in ("bwd_input", "fc1_bwd_input", "fc2_bwd_input"):
        transA, transB = False, True
    elif mode in ("bwd_weight", "fc1_bwd_weight", "fc2_bwd_weight"):
        transA, transB = True, False
    else:
        raise ValueError(f"Unsupported gemm mode: {mode}")
return tex.te_gemm(
B,
None,
B_dtype,
B_offset,
A,
None,
A_dtype,
A_offset,
get_workspace(),
use_bias,
bias,
gelu,
gelu_input,
transB,
transA,
grad,
        False,  # accumulate
        False,  # use_split_accumulator
TE_DType[output_dtype],
sid,
)
def fp8_matmul_wrapper(
inp,
weight,
fp8_meta,
mode,
A_dtype,
B_dtype,
output_dtype,
use_split_accumulate,
sid,
use_bias=False,
bias=None,
):
"""Wrapper to call the tex.te_gemm for the fp8 gemm."""
A = inp
B = weight
if mode in ("fwd", "fc1_fwd"):
A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
A_offset = tex.FP8FwdTensors.GEMM1_INPUT
B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
B_offset = tex.FP8FwdTensors.GEMM1_WEIGHT
elif mode == "fc2_fwd":
A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
A_offset = tex.FP8FwdTensors.GEMM2_INPUT
B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
B_offset = tex.FP8FwdTensors.GEMM2_WEIGHT
elif mode == "bwd_input":
A_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
A_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
B_offset = tex.FP8FwdTensors.GEMM1_WEIGHT
elif mode == "fc1_bwd_input":
A_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
A_offset = tex.FP8BwdTensors.GRAD_OUTPUT2
B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
B_offset = tex.FP8FwdTensors.GEMM1_WEIGHT
elif mode == "fc2_bwd_input":
A_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
A_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
B_offset = tex.FP8FwdTensors.GEMM2_WEIGHT
elif mode == "bwd_weight":
A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
A_offset = tex.FP8FwdTensors.GEMM1_INPUT
B_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
B_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
elif mode == "fc2_bwd_weight":
A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
A_offset = tex.FP8FwdTensors.GEMM2_INPUT
B_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
B_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
elif mode == "fc1_bwd_weight":
A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
A_offset = tex.FP8FwdTensors.GEMM1_INPUT
B_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
B_offset = tex.FP8BwdTensors.GRAD_OUTPUT2
return tex.te_gemm(
B,
B_scale_inv,
B_dtype,
B_offset,
A,
A_scale_inv,
A_dtype,
A_offset,
get_workspace(),
use_bias,
bias,
False, # use_gelu
None, # gelu_input
True, # transa
False, # transb
False, # grad
False, # accumulate
use_split_accumulate,
TE_DType[output_dtype],
sid,
)
def layernorm_fwd_fp8_wrapper(
x, ln_gamma, ln_beta, epsilon, fp8_meta, amax_index, output_dtype, sid
):
"""Wrapper to call the tex.layernorm_fwd_fp8."""
scaling_key = "scaling_fwd"
scale = fp8_meta[scaling_key]["scale"].value()
amax = fp8_meta[scaling_key]["amax_history"].value()
scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
ln_out, mu, rsigma = tex.layernorm_fwd_fp8(
x,
ln_gamma,
ln_beta,
epsilon,
scale,
output_dtype,
amax,
scale_inv,
amax_index,
sid,
)
return ln_out, mu, rsigma
# The DelayedScaling object is not supported in TF autograd. So, to avoid
# passing this object to the custom gradient function, we only extract the
# useful information.
def get_recipe_attrs(recipe):
"""Get attributes from the recipe."""
fp8_dtype_fwd = get_fp8_te_dtype(recipe, fprop_tensor=True)
fp8_dtype_bwd = get_fp8_te_dtype(recipe, fprop_tensor=False)
override_linear_precision = recipe.override_linear_precision
return (fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision)
# TransformerEngineBaseModule is a mixin class and its init function will pass
# through all the positional and keyword arguments to other subclasses. Make
# sure this class is inherited first.
class TransformerEngineBaseModule:
"""Base TE module."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# fp8 related
self.fp8 = False
self.fp8_meta = {}
self.fp8_meta["recipe"] = get_default_fp8_recipe()
self.fp8_meta_tensors_initialized = False
self.fp8_weight_shapes = []
self.stream_id = get_stream_id()
def set_meta_tensor(self, fwd):
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
num_fp8_tensors = (
self.fp8_meta["num_gemms"] * 2 if fwd else
self.fp8_meta["num_gemms"]
)
self.fp8_meta[fp8_meta_tensor_key] = {}
self.fp8_meta[fp8_meta_tensor_key]["scale"] = tf.Variable(
tf.ones((num_fp8_tensors), dtype=tf.float32), trainable=False
)
self.fp8_meta[fp8_meta_tensor_key]["scale_inv"] = tf.Variable(
tf.ones((num_fp8_tensors), dtype=tf.float32), trainable=False
)
self.fp8_meta[fp8_meta_tensor_key]["amax_history"] = tf.Variable(
tf.zeros(
(self.fp8_meta["recipe"].amax_history_len, num_fp8_tensors),
dtype=tf.float32,
),
trainable=False,
)
def init_fp8_meta_tensors(self):
"""Init scales and amaxes."""
# Checkpoint loaded
if self.fp8_meta_tensors_initialized:
return
self.set_meta_tensor(True)
self.set_meta_tensor(False)
def fp8_init(self, num_gemms=1):
"""Initialize fp8 related metadata and tensors during fprop."""
if not is_fp8_enabled():
self.fp8 = False
return
# FP8 is already enabled and recipe is the same, don't do anything.
if self.fp8 and get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8 = True
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
# Set FP8_MAX per tensor according to recipe
fp8_format_val = self.fp8_meta["recipe"].fp8_format.value
self.fp8_meta["fp8_max_fwd"] = fp8_format_val.max_fwd
self.fp8_meta["fp8_max_bwd"] = fp8_format_val.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
def pre_forward(self, training, num_gemms=1):
"""Checks and prep for FWD."""
self.fp8_init(num_gemms=num_gemms)
if self.fp8:
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
# Previous iteration was grad_enabled
amax_and_scale_update(self.fp8_meta, True)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
if training:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
self.fp8_meta["update_amax_and_scale_fwd"] = True
                # Placeholder value; the backward pass fills in the real
                # autocast id so autograd can match fwd/bwd fp8 metadata.
                self.fp8_meta["autocast_id_bwd"] = -1
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
def pre_backward(self):
"""Checks and prep for BWD."""
# From previous iteration
amax_and_scale_update(self.fp8_meta, False)
set_amax_buffer_key_deletion(self.fp8_meta, forward=False)
class Dense(TransformerEngineBaseModule, layers.Layer):
"""
Applies a linear transformation to the incoming data :math:`y = xW + b`
On NVIDIA GPUs it is a drop-in replacement for `tf.keras.layers.Dense`.
Parameters
----------
units : int
size of each output sample.
use_bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
kernel_initializer: Callable, default = `None`
used for initializing weights in the following way:
`kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
bias_initializer: Callable, default = `None`
used for initializing biases in the following way:
`bias_initializer(weight)`. When set to `None`, defaults to `zeros`.
Parallelism parameters
----------------------
skip_weight_param_allocation: bool, default = `False`
        if set to `True`, the weight parameter is not allocated and must be
        passed as a keyword argument `kernel` during the forward pass.
Optimization parameters
-----------------------
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself,
but instead return the bias value during the forward pass together with
the output of the linear transformation :math:`y = xW`. This is useful
when the bias addition can be fused to subsequent operations.
"""
def __init__(
self,
units: int,
use_bias: bool = True,
return_bias: bool = False,
kernel_initializer: Union[Callable, str, None] = None,
bias_initializer: Union[Callable, str, None] = None,
skip_weight_param_allocation: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.units = units
self.use_bias = use_bias
self.return_bias = return_bias
self.kernel_initializer = get_init_method(
kernel_initializer, initializers.RandomNormal(mean=0.0,
stddev=0.023)
)
self.bias_initializer = get_init_method(
bias_initializer, initializers.get("zeros")
)
self.skip_weight_param_allocation = skip_weight_param_allocation
def build(self, input_shape):
"""One-time allocation of the variables."""
input_shape = tf.TensorShape(input_shape)
last_dim = tf.compat.dimension_value(input_shape[-1])
if last_dim is None:
raise ValueError(
"The last dimension of the inputs to a Dense layer should be "
f"defined. Found None. Full input shape received: {input_shape}"
)
self.kernel = None
self.bias = None
if not self.skip_weight_param_allocation:
self.kernel = self.add_weight(
name="kernel",
shape=(last_dim, self.units),
initializer=self.kernel_initializer,
trainable=True,
)
if self.use_bias or self.return_bias:
self.bias = self.add_weight(
name="bias",
shape=(self.units,),
initializer=self.bias_initializer,
trainable=True,
)
# fp8 related
self.fp8_weight_shapes.append((last_dim, self.units))
self.built = True
def _get_training_value(self, training=None):
if training is None:
training = backend.learning_phase()
if isinstance(training, int):
training = bool(training)
if not self.trainable:
# When the layer is not trainable, it overrides the value passed
# from model.
training = False
return training
def non_fp8_matmul(
self,
inp: tf.Tensor,
kernel_var: tf.Variable,
bias_var: Union[tf.Variable, None] = None,
):
"""Prep fwd+bwd non-fp8 matmul."""
@tf.custom_gradient
def non_fp8_matmul_func(x):
# Use value() to convert from Variable to EagerTensor
kernel_val = kernel_var.value()
bias = get_autocast_bias(
self._compute_dtype_object, bias_var, self.use_bias,
use_fp8=False,
)
output_dtype = self._compute_dtype_object
outputs = matmul_wrapper(
x, kernel_val, "fwd", output_dtype, self.stream_id,
self.use_bias, bias,
)
def grad_fn(upstream, variables=None):
grad_x = matmul_wrapper(
upstream, kernel_val, "bwd_input", output_dtype,
self.stream_id,
)
grad_weight = matmul_wrapper(
x, upstream, "bwd_weight", output_dtype, self.stream_id
)
if self.use_bias:
grad_bias = tf.math.reduce_sum(upstream, axis=0)
grad_inputs = [grad_x]
grad_vars = []
for v in variables:
if v.name.endswith("bias:0") and self.use_bias:
grad_vars.append(grad_bias)
elif v.name.endswith("kernel:0"):
grad_vars.append(grad_weight)
return grad_inputs, grad_vars
return outputs, grad_fn
return non_fp8_matmul_func(inp)
def fp8_matmul(
self,
inp: tf.Tensor,
kernel_var: tf.Variable,
bias_var: Union[tf.Variable, None] = None,
):
"""Prep fwd+bwd fp8 matmul."""
fp8_meta = self.fp8_meta
fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision = \
get_recipe_attrs(fp8_meta["recipe"])
@tf.custom_gradient
def fp8_matmul_func(x):
# Use value() to convert from Variable to EagerTensor
kernel_val = kernel_var.value()
bias = get_autocast_bias(
self._compute_dtype_object, bias_var, self.use_bias,
use_fp8=True,
)
if not override_linear_precision.wgrad:
x_fp8, x_t_fp8 = fp8_cast_transpose_fused_wrapper(
x,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
True,
fp8_dtype_fwd,
self.stream_id,
)
else:
x_fp8 = cast_to_fp8_wrapper(
x,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
True,
fp8_dtype_fwd,
self.stream_id,
)
weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
kernel_val,
fp8_meta,
tex.FP8FwdTensors.GEMM1_WEIGHT,
True,
fp8_dtype_fwd,
self.stream_id,
)
output_dtype = self._compute_dtype_object
outputs = fp8_matmul_wrapper(
x_fp8,
weight_t_fp8,
fp8_meta,
"fwd",
fp8_dtype_fwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_FPROP,
self.stream_id,
self.use_bias,
bias,
)
def grad_fn(upstream, variables=None):
self.pre_backward()
if self.use_bias:
(
grad_bias,
grad_fp8,
grad_t_fp8,
) = fp8_cast_transpose_bgrad_fused_wrapper(
upstream,
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
if not override_linear_precision.wgrad:
grad_fp8, grad_t_fp8 = fp8_cast_transpose_fused_wrapper(
upstream,
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
grad_fp8 = cast_to_fp8_wrapper(
upstream,
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
grad_x = fp8_matmul_wrapper(
grad_fp8,
weight_fp8,
fp8_meta,
"bwd_input",
fp8_dtype_bwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_DGRAD,
self.stream_id,
)
if not override_linear_precision.wgrad:
grad_weight = fp8_matmul_wrapper(
x_t_fp8,
grad_t_fp8,
fp8_meta,
"bwd_weight",
fp8_dtype_fwd,
fp8_dtype_bwd,
output_dtype,
_2X_ACC_WGRAD,
self.stream_id,
)
else:
grad_weight = matmul_wrapper(
x, upstream, "bwd_weight", output_dtype, self.stream_id
)
grad_inputs = [grad_x]
grad_vars = []
for v in variables:
if v.name.endswith("bias:0") and self.use_bias:
grad_vars.append(grad_bias)
elif v.name.endswith("kernel:0"):
grad_vars.append(grad_weight)
return grad_inputs, grad_vars
return outputs, grad_fn
return fp8_matmul_func(inp)
def call(
self,
inputs,
kernel=None,
bias=None,
training=None,
):
"""
Apply the linear transformation to the input.
Parameters
----------
        inputs : tf.Tensor
            Input tensor.
        kernel : tf.Variable, default = None
            An optional weight tensor for the module. This argument is
            compulsory if the module is initialized with
            `skip_weight_param_allocation=True`.
        bias : tf.Variable, default = None
            An optional bias tensor for the module. This argument is
            compulsory if the module is initialized with
            `skip_weight_param_allocation=True` and one of `use_bias` or
            `return_bias` is enabled.
training : {True, False, None}, default = None
Whether this is in the training context.
"""
        # self.pre_forward needs to be called outside the following branch,
        # since it sets self.fp8 when the fp8 autocast context is detected.
training = self._get_training_value(training)
self.pre_forward(training)
kernel_var = (kernel if self.skip_weight_param_allocation else
self.kernel)
bias_var = bias if self.skip_weight_param_allocation else self.bias
if kernel_var is None:
raise ValueError("No valid kernel is provided")
inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))
if self.fp8:
outputmat = self.fp8_matmul(inputmat, kernel_var, bias_var)
else:
outputmat = self.non_fp8_matmul(inputmat, kernel_var, bias_var)
outputs = tf.reshape(
outputmat, shape=(-1, *inputs.shape[1:-1], outputmat.shape[-1])
)
if self.return_bias:
return outputs, bias_var
return outputs
def get_config(self):
"""Returns the config of the layer."""
config = super().get_config()
        config.update(
            {
                "units": self.units,
                "use_bias": self.use_bias,
                "return_bias": self.return_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer),
                "skip_weight_param_allocation":
                    self.skip_weight_param_allocation,
            }
        )
        return config
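# Illustrative usage of the fp8 Dense layer (a sketch; it assumes the
# `fp8_autocast` context manager exported by transformer_engine.tensorflow):
#
#   import transformer_engine.tensorflow as te
#
#   dense = te.Dense(units=768)
#   x = tf.random.normal((16, 768))
#   with te.fp8_autocast(enabled=True):
#       with tf.GradientTape() as tape:
#           y = dense(x, training=True)
#   grads = tape.gradient(y, dense.trainable_variables)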
class LayerNorm(layers.Layer):
"""
Applies Layer Normalization over a mini-batch of inputs.
Parameters
----------
epsilon : float, default = 1e-3
a value added to the denominator of layer normalization for numerical
stability.
gamma_initializer: Callable, default = `None`
used for initializing LayerNorm gamma in the following way:
`gamma_initializer(weight)`. When set to `None`, defaults to `ones`.
beta_initializer: Callable, default = `None`
used for initializing LayerNorm beta in the following way:
`beta_initializer(weight)`. When set to `None`, defaults to `zeros`.
"""
def __init__(
self, epsilon=1e-3, gamma_initializer="ones", beta_initializer="zeros",
**kwargs
):
super().__init__(**kwargs)
self.epsilon = epsilon
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.stream = get_stream_id()
def build(self, input_shape):
"""One-time allocation of the variables."""
input_shape = tf.TensorShape(input_shape)
last_dim = tf.compat.dimension_value(input_shape[-1])
if last_dim is None:
raise ValueError(
"The last dimension of the inputs to a Dense layer should be "
f"defined. Found None. Full input shape received: {input_shape}"
)
self.gamma = self.add_weight(
name="gamma",
shape=(last_dim,),
initializer=self.gamma_initializer,
trainable=True,
)
self.beta = self.add_weight(
name="beta",
shape=(last_dim,),
initializer=self.beta_initializer,
trainable=True,
)
self.built = True
@tf.custom_gradient
def layernorm(self, inp: tf.Tensor):
"""Prep fwd+bwd non-fp8 layernorm."""
gamma = self.gamma.value()
ln_out, mu, rsigma = tex.layernorm_fwd(
inp, gamma, self.beta.value(), self.epsilon, self.stream
)
def grad_fn(upstream, variables=None):
# pylint: disable=unused-argument
dxmat, dgamma, dbeta = tex.layernorm_bwd(
upstream, inp, mu, rsigma, gamma, self.stream
)
grad_inputs = [tf.reshape(dxmat, inp.shape)]
grad_vars = [dgamma, dbeta]
return grad_inputs, grad_vars
return ln_out, grad_fn
def call(self, inputs):
"""LayerNorm FWD"""
inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))
outputmat = self.layernorm(inputmat)
outputs = tf.reshape(outputmat, shape=inputs.shape)
return outputs
def get_config(self):
"""Returns the config of the layer."""
config = super().get_config()
config.update(
{
"epsilon": self.epsilon,
"gamma_initializer": initializers.serialize(
self.gamma_initializer),
"beta_initializer": initializers.serialize(
self.beta_initializer),
}
        )
        return config
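# Illustrative usage (hypothetical shapes): normalizes over the last
# dimension with learned per-feature gamma/beta:
#
#   ln = LayerNorm(epsilon=1e-3)
#   x = tf.random.normal((4, 128, 768))
#   y = ln(x)  # same shape as x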
class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
"""
Applies layer normalization followed by linear transformation to the
incoming data.
Parameters
----------
units : int
size of each output sample.
epsilon : float, default = 1e-3
a value added to the denominator of layer normalization for numerical
stability.
use_bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
gamma_initializer: Callable, default = `None`
used for initializing LayerNorm gamma in the following way:
`gamma_initializer(weight)`. When set to `None`, defaults to `ones`.
beta_initializer: Callable, default = `None`
used for initializing LayerNorm beta in the following way:
`beta_initializer(weight)`. When set to `None`, defaults to `zeros`.
kernel_initializer : Callable, default = `None`
used for initializing GEMM weights in the following way:
`kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
bias_initializer : Callable, default = `None`
used for initializing GEMM bias in the following way:
`bias_initializer(weight)`. When set to `None`, defaults to `zeros`.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is taken post
layernorm.
Parallelism parameters
----------------------
skip_weight_param_allocation: bool, default = `False`
        if set to `True`, the weight parameter is not allocated and must be
        passed as a keyword argument `kernel` during the forward pass.
Optimization parameters
-----------------------
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself,
but instead return the bias value during the forward pass together with
the output of the linear transformation :math:`y = xW`. This is useful
when the bias addition can be fused to subsequent operations.
"""
def __init__(
self,
units,
epsilon=1e-3,
gamma_initializer: Union[Callable, str, None] = None,
beta_initializer: Union[Callable, str, None] = None,
return_layernorm_output=False,
use_bias=True,
return_bias=False,
kernel_initializer: Union[Callable, str, None] = None,
bias_initializer: Union[Callable, str, None] = None,
skip_weight_param_allocation=False,
**kwargs,
):
super().__init__(**kwargs)
self.units = units
self.epsilon = epsilon
self.gamma_initializer = get_init_method(
gamma_initializer, initializers.get("ones")
)
self.beta_initializer = get_init_method(
beta_initializer, initializers.get("zeros")
)
self.return_layernorm_output = return_layernorm_output
self.use_bias = use_bias
self.return_bias = return_bias
self.kernel_initializer = get_init_method(
kernel_initializer, initializers.RandomNormal(mean=0.0,
stddev=0.023)
)
self.bias_initializer = get_init_method(
bias_initializer, initializers.get("zeros")
)
self.skip_weight_param_allocation = skip_weight_param_allocation
def build(self, input_shape):
"""One-time allocation of the variables."""
input_shape = tf.TensorShape(input_shape)
last_dim = tf.compat.dimension_value(input_shape[-1])
if last_dim is None:
raise ValueError(
"The last dimension of the inputs to a Dense layer should be "
f"defined. Found None. Full input shape received: {input_shape}"
)
self.gamma = self.add_weight(
name="gamma",
shape=(last_dim,),
initializer=self.gamma_initializer,
trainable=True,
)
self.beta = self.add_weight(
name="beta",
shape=(last_dim,),
initializer=self.beta_initializer,
trainable=True,
)
self.kernel = None
self.bias = None
if not self.skip_weight_param_allocation:
self.kernel = self.add_weight(
name="kernel",
shape=(last_dim, self.units),
initializer=self.kernel_initializer,
trainable=True,
)
if self.use_bias or self.return_bias:
self.bias = self.add_weight(
name="bias",
shape=(self.units,),
initializer=self.bias_initializer,
trainable=True,
)
# fp8 related
self.fp8_weight_shapes.append((last_dim, self.units))
self.built = True
def _get_training_value(self, training=None):
if training is None:
training = backend.learning_phase()
if isinstance(training, int):
training = bool(training)
if not self.trainable:
# When the layer is not trainable, it overrides the value passed
# from model.
training = False
return training
def non_fp8_layernorm_matmul(
self,
inp: tf.Tensor,
gamma_var: tf.Variable,
beta_var: tf.Variable,
kernel_var: tf.Variable,
bias_var: Union[tf.Variable, None] = None,
):
"""Prep fwd+bwd non-fp8 layernorm followed by matmul."""
@tf.custom_gradient
def non_fp8_layernorm_matmul_func(x):
# Use value() to convert from Variable to EagerTensor
kernel_val = kernel_var.value()
gamma_val = gamma_var.value()
beta_val = beta_var.value()
ln_out, mu, rsigma = tex.layernorm_fwd(
x, gamma_val, beta_val, self.epsilon, self.stream_id
)
bias = get_autocast_bias(
self._compute_dtype_object, bias_var, self.use_bias,
use_fp8=False,
)
output_dtype = self._compute_dtype_object
outputs = matmul_wrapper(
ln_out,
kernel_val,
"fwd",
output_dtype,
self.stream_id,
self.use_bias,
bias,
)
def grad_fn(*upstream, variables=None):
grad_x = matmul_wrapper(
upstream[0], kernel_val, "bwd_input", output_dtype,
self.stream_id,
)
grad_weight = matmul_wrapper(
ln_out, upstream[0], "bwd_weight", output_dtype,
self.stream_id,
)
if self.use_bias:
grad_bias = tf.math.reduce_sum(upstream[0], axis=0)
if self.return_layernorm_output:
assert len(upstream) == 2
grad_x = grad_x + upstream[1]
dxmat, dgamma, dbeta = tex.layernorm_bwd(
grad_x, x, mu, rsigma, gamma_val, self.stream_id
)
grad_inputs = [dxmat]
grad_vars = []
for v in variables:
if v.name.endswith("gamma:0"):
grad_vars.append(dgamma)
elif v.name.endswith("bias:0") and self.use_bias:
grad_vars.append(grad_bias)
elif v.name.endswith("kernel:0"):
grad_vars.append(grad_weight)
elif v.name.endswith("beta:0"):
grad_vars.append(dbeta)
return grad_inputs, grad_vars
if self.return_layernorm_output:
return (outputs, ln_out), grad_fn
return outputs, grad_fn
return non_fp8_layernorm_matmul_func(inp)
def fp8_layernorm_matmul(
self,
inp: tf.Tensor,
gamma_var: tf.Variable,
beta_var: tf.Variable,
kernel_var: tf.Variable,
bias_var: Union[tf.Variable, None] = None,
):
"""Prep fwd+bwd fp8 layernorm followed by matmul."""
fp8_meta = self.fp8_meta
fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision = \
get_recipe_attrs(fp8_meta["recipe"])
@tf.custom_gradient
def fp8_layernorm_matmul_func(x):
# Use value() to convert from Variable to EagerTensor
kernel_val = kernel_var.value()
gamma_val = gamma_var.value()
beta_val = beta_var.value()
if not self.return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8_wrapper(
x,
gamma_val,
beta_val,
self.epsilon,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_fwd,
self.stream_id,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
x, gamma_val, beta_val, self.epsilon, self.stream_id
)
ln_out = cast_to_fp8_wrapper(
ln_out_return,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
True,
fp8_dtype_fwd,
self.stream_id,
)
bias = get_autocast_bias(
self._compute_dtype_object, bias_var, self.use_bias,
use_fp8=True,
)
weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
kernel_val,
fp8_meta,
tex.FP8FwdTensors.GEMM1_WEIGHT,
True,
fp8_dtype_fwd,
self.stream_id,
)
output_dtype = self._compute_dtype_object
outputs = fp8_matmul_wrapper(
ln_out,
weight_t_fp8,
fp8_meta,
"fwd",
fp8_dtype_fwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_FPROP,
self.stream_id,
self.use_bias,
bias,
)
def grad_fn(*upstream, variables=None):
self.pre_backward()
if self.use_bias:
(
grad_bias,
grad_fp8,
grad_t_fp8,
) = fp8_cast_transpose_bgrad_fused_wrapper(
upstream[0],
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
if not override_linear_precision.wgrad:
grad_fp8, grad_t_fp8 = fp8_cast_transpose_fused_wrapper(
upstream[0],
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
grad_fp8 = cast_to_fp8_wrapper(
upstream[0],
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
grad_x = fp8_matmul_wrapper(
grad_fp8,
weight_fp8,
fp8_meta,
"bwd_input",
fp8_dtype_bwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_DGRAD,
self.stream_id,
)
if not override_linear_precision.wgrad:
ln_out_t = tex.fp8_transpose(ln_out, fp8_dtype_fwd,
self.stream_id)
grad_weight = fp8_matmul_wrapper(
ln_out_t,
grad_t_fp8,
fp8_meta,
"bwd_weight",
fp8_dtype_fwd,
fp8_dtype_bwd,
output_dtype,
_2X_ACC_WGRAD,
self.stream_id,
)
else:
ln_out_c = cast_from_fp8_wrapper(
ln_out,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
True,
fp8_dtype_fwd,
TE_DType[x.dtype],
self.stream_id,
)
grad_weight = matmul_wrapper(
ln_out_c,
upstream[0],
"bwd_weight",
output_dtype,
self.stream_id,
)
if self.return_layernorm_output:
assert len(upstream) == 2
grad_x = grad_x + upstream[1]
dxmat, dgamma, dbeta = tex.layernorm_bwd(
grad_x, x, mu, rsigma, gamma_val, self.stream_id
)
grad_inputs = [dxmat]
grad_vars = []
for v in variables:
if v.name.endswith("gamma:0"):
grad_vars.append(dgamma)
elif v.name.endswith("bias:0") and self.use_bias:
grad_vars.append(grad_bias)
elif v.name.endswith("kernel:0"):
grad_vars.append(grad_weight)
elif v.name.endswith("beta:0"):
grad_vars.append(dbeta)
return grad_inputs, grad_vars
if self.return_layernorm_output:
return (outputs, ln_out_return), grad_fn
return outputs, grad_fn
return fp8_layernorm_matmul_func(inp)
def call(
self,
inputs,
kernel=None,
bias=None,
training=None,
):
"""
Apply layer normalization to the input followed by a linear
transformation.
Parameters
----------
inputs : tf.Tensor
Input tensor.
kernel : tf.Variable, default = None
            An optional weight tensor for the module. This argument is
            compulsory if the module is initialized with
            `skip_weight_param_allocation=True`.
        bias : tf.Variable, default = None
            An optional bias tensor for the module. This argument is
            compulsory if the module is initialized with
            `skip_weight_param_allocation=True` and one of `use_bias` or
            `return_bias` is enabled.
training : {True, False, None}, default = None
Whether this is in the training context.
"""
        # self.pre_forward needs to be called outside the following branch,
        # since it has the side effect of setting self.fp8 when the fp8
        # autocast context is detected.
training = self._get_training_value(training)
self.pre_forward(training)
kernel_var = (kernel if self.skip_weight_param_allocation else
self.kernel)
bias_var = bias if self.skip_weight_param_allocation else self.bias
if kernel_var is None:
raise ValueError("No valid kernel is provided")
inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))
if self.fp8:
outputs = self.fp8_layernorm_matmul(
inputmat, self.gamma, self.beta, kernel_var, bias_var
)
else:
outputs = self.non_fp8_layernorm_matmul(
inputmat, self.gamma, self.beta, kernel_var, bias_var
)
if self.return_layernorm_output:
outputmat, ln_outputmat = outputs
else:
outputmat = outputs
outputs = tf.reshape(
outputmat, shape=(-1, *inputs.shape[1:-1], outputmat.shape[-1])
)
if self.return_bias:
if self.return_layernorm_output:
ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
return (outputs, bias_var, ln_outputs)
return outputs, bias_var
if self.return_layernorm_output:
ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
return (outputs, ln_outputs)
return outputs
def get_config(self):
"""Returns the config of the layer."""
config = super().get_config()
        config.update(
            {
                "units": self.units,
                "epsilon": self.epsilon,
                "gamma_initializer": initializers.serialize(
                    self.gamma_initializer),
                "beta_initializer": initializers.serialize(
                    self.beta_initializer),
                "return_layernorm_output": self.return_layernorm_output,
                "use_bias": self.use_bias,
                "return_bias": self.return_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer),
                "skip_weight_param_allocation":
                    self.skip_weight_param_allocation,
            }
        )
        return config
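# Illustrative usage (hypothetical sizes): fuses the pre-gemm layernorm with
# the linear transformation; with return_layernorm_output=True the normalized
# activations are also returned for a post-layernorm residual branch:
#
#   ln_dense = LayerNormDense(units=3072, return_layernorm_output=True)
#   x = tf.random.normal((16, 768))
#   y, ln_out = ln_dense(x, training=True)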
class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
"""
Applies layer normalization on the input followed by the MLP module,
consisting of 2 successive linear transformations, separated by the GeLU
activation.
Parameters
----------
    units : int
        size of the output of the first linear transformation (FC1).
    ffn_units : int
        size of the output of the second linear transformation (FC2).
epsilon : float, default = 1e-3
a value added to the denominator of layer normalization for numerical
stability.
gamma_initializer: Callable, default = `None`
used for initializing LayerNorm gamma in the following way:
`gamma_initializer(weight)`. When set to `None`, defaults to `ones`.
beta_initializer: Callable, default = `None`
used for initializing LayerNorm beta in the following way:
`beta_initializer(weight)`. When set to `None`, defaults to `zeros`.
use_bias : bool, default = `True`
if set to `False`, the FC2 layer will not learn an additive bias.
kernel_initializer: Callable, default = `None`
used for initializing FC1 weights in the following way:
`kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
ffn_kernel_initializer: Callable, default = `None`
used for initializing FC2 weights in the following way:
`ffn_kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is taken post
layernorm.
bias_initializer: Callable, default = `None`
used for initializing FC1 and FC2 bias in the following way:
`bias_initializer(weight)`. When set to `None`, defaults to `zeros`.
Optimization parameters
-----------------------
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself,
but instead return the bias value during the forward pass together with
the output of the linear transformation :math:`y = xW`. This is useful
when the bias addition can be fused to subsequent operations.
"""
def __init__(
self,
units: int,
ffn_units: int,
epsilon: float = 1e-3,
gamma_initializer: Union[Callable, str, None] = None,
beta_initializer: Union[Callable, str, None] = None,
return_layernorm_output: bool = False,
use_bias: bool = True,
return_bias: bool = False,
kernel_initializer: Union[Callable, str, None] = None,
ffn_kernel_initializer: Union[Callable, str, None] = None,
bias_initializer: Union[Callable, str, None] = None,
**kwargs,
):
super().__init__(**kwargs)
self.fc1_units = units
self.fc2_units = ffn_units
self.epsilon = epsilon
self.gamma_initializer = get_init_method(
gamma_initializer, initializers.get("ones")
)
self.beta_initializer = get_init_method(
beta_initializer, initializers.get("zeros")
)
self.return_layernorm_output = return_layernorm_output
self.use_bias = use_bias
self.return_bias = return_bias
self.kernel1_initializer = get_init_method(
kernel_initializer, initializers.RandomNormal(mean=0.0,
stddev=0.023)
)
self.kernel2_initializer = get_init_method(
ffn_kernel_initializer, initializers.RandomNormal(mean=0.0,
stddev=0.023)
)
self.bias_initializer = get_init_method(
bias_initializer, initializers.get("zeros")
)
def build(self, input_shape):
"""One-time allocation of the variables."""
input_shape = tf.TensorShape(input_shape)
last_dim = tf.compat.dimension_value(input_shape[-1])
if last_dim is None:
raise ValueError(
"The last dimension of the inputs to a Dense layer should be "
f"defined. Found None. Full input shape received: {input_shape}"
)
self.gamma = self.add_weight(
name="gamma",
shape=(last_dim,),
initializer=self.gamma_initializer,
trainable=True,
)
self.beta = self.add_weight(
name="beta",
shape=(last_dim,),
initializer=self.beta_initializer,
trainable=True,
)
self.fc1_kernel = self.add_weight(
name="fc1_kernel",
shape=(last_dim, self.fc1_units),
initializer=self.kernel1_initializer,
trainable=True,
)
self.fc1_bias = self.add_weight(
name="fc1_bias",
shape=(self.fc1_units,),
initializer=self.bias_initializer,
trainable=True,
)
# fp8 related
self.fp8_weight_shapes.append((last_dim, self.fc1_units))
self.fc2_kernel = self.add_weight(
name="fc2_kernel",
shape=(self.fc1_units, self.fc2_units),
initializer=self.kernel2_initializer,
trainable=True,
)
self.fc2_bias = None
if self.use_bias or self.return_bias:
self.fc2_bias = self.add_weight(
name="fc2_bias",
shape=(self.fc2_units,),
initializer=self.bias_initializer,
trainable=True,
)
# fp8 related
self.fp8_weight_shapes.append((self.fc1_units, self.fc2_units))
self.built = True
def _get_training_value(self, training=None):
if training is None:
training = backend.learning_phase()
if isinstance(training, int):
training = bool(training)
if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from model.
training = False
return training
def non_fp8_layernorm_mlp(
self,
inp: tf.Tensor,
gamma_var: tf.Variable,
beta_var: tf.Variable,
fc1_kernel_var: tf.Variable,
fc1_bias_var: tf.Variable,
fc2_kernel_var: tf.Variable,
fc2_bias_var: Union[tf.Variable, None] = None,
):
"""Prep fwd+bwd non-fp8 layernorm followed by mlp."""
@tf.custom_gradient
def non_fp8_layernorm_mlp_func(x):
# Use value() to convert from Variable to EagerTensor
fc1_kernel_val = fc1_kernel_var.value()
fc2_kernel_val = fc2_kernel_var.value()
gamma_val = gamma_var.value()
beta_val = beta_var.value()
ln_out, mu, rsigma = tex.layernorm_fwd(
x, gamma_val, beta_val, self.epsilon, self.stream_id
)
fc1_bias = get_autocast_bias(
self._compute_dtype_object, fc1_bias_var, use_bias=True,
use_fp8=False,
)
fc2_bias = get_autocast_bias(
self._compute_dtype_object, fc2_bias_var, self.use_bias,
use_fp8=False,
)
output_dtype = self._compute_dtype_object
            # TODO(kaixih): Ideally, we should set gelu=True to fuse the gelu
            # into the cuBlasLt call, but the fused path currently appears
            # slower than the unfused version. Revisit when cuBlasLt improves.
fc1_out = matmul_wrapper(
ln_out,
fc1_kernel_val,
"fc1_fwd",
output_dtype,
self.stream_id,
use_bias=True,
bias=fc1_bias,
)
gelu_out = tex.te_gelu(
fc1_out, None, TE_DType[output_dtype], None, None, 0,
self.stream_id,
)
fc2_out = matmul_wrapper(
gelu_out,
fc2_kernel_val,
"fc2_fwd",
output_dtype,
self.stream_id,
use_bias=self.use_bias,
bias=fc2_bias,
)
def grad_fn(*upstream, variables=None):
fc2_dgrad = matmul_wrapper(
upstream[0],
fc2_kernel_val,
"fc2_bwd_input",
output_dtype,
self.stream_id,
grad=True,
gelu=True,
gelu_input=fc1_out,
)
fc2_wgrad = matmul_wrapper(
gelu_out, upstream[0], "bwd_weight", output_dtype,
self.stream_id,
)
if self.use_bias:
fc2_bias_grad = tf.math.reduce_sum(upstream[0], axis=0)
dgelu = fc2_dgrad
fc1_dgrad = matmul_wrapper(
dgelu, fc1_kernel_val, "fc1_bwd_input", output_dtype,
self.stream_id,
)
fc1_wgrad = matmul_wrapper(
ln_out, dgelu, "bwd_weight", output_dtype, self.stream_id
)
fc1_bias_grad = tf.math.reduce_sum(dgelu, axis=0)
d_ln_out = fc1_dgrad
if self.return_layernorm_output:
assert len(upstream) == 2
d_ln_out = d_ln_out + upstream[1]
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, x, mu, rsigma, gamma_val, self.stream_id
)
grad_inputs = [dxmat]
grad_vars = []
for v in variables:
if v.name.endswith("gamma:0"):
grad_vars.append(dgamma)
elif v.name.endswith("fc1_kernel:0"):
grad_vars.append(fc1_wgrad)
elif v.name.endswith("fc1_bias:0"):
grad_vars.append(fc1_bias_grad)
elif v.name.endswith("fc2_kernel:0"):
grad_vars.append(fc2_wgrad)
elif v.name.endswith("fc2_bias:0") and self.use_bias:
grad_vars.append(fc2_bias_grad)
elif v.name.endswith("beta:0"):
grad_vars.append(dbeta)
return grad_inputs, grad_vars
if self.return_layernorm_output:
return (fc2_out, ln_out), grad_fn
return fc2_out, grad_fn
return non_fp8_layernorm_mlp_func(inp)
def fp8_layernorm_mlp(
self,
inp: tf.Tensor,
gamma_var: tf.Variable,
beta_var: tf.Variable,
fc1_kernel_var: tf.Variable,
fc1_bias_var: tf.Variable,
fc2_kernel_var: tf.Variable,
fc2_bias_var: Union[tf.Variable, None] = None,
):
"""Prep fwd+bwd fp8 layernorm followed by mlp."""
fp8_meta = self.fp8_meta
fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision = \
get_recipe_attrs(fp8_meta["recipe"])
@tf.custom_gradient
def fp8_layernorm_mlp_func(x):
# Use value() to convert from Variable to EagerTensor
fc1_kernel_val = fc1_kernel_var.value()
fc2_kernel_val = fc2_kernel_var.value()
gamma_val = gamma_var.value()
beta_val = beta_var.value()
if not self.return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8_wrapper(
x,
gamma_val,
beta_val,
self.epsilon,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_fwd,
self.stream_id,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
x, gamma_val, beta_val, self.epsilon, self.stream_id
)
ln_out = cast_to_fp8_wrapper(
ln_out_return,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
True,
fp8_dtype_fwd,
self.stream_id,
)
fc1_bias = get_autocast_bias(
self._compute_dtype_object, fc1_bias_var, use_bias=True,
use_fp8=True,
)
fc2_bias = get_autocast_bias(
self._compute_dtype_object, fc2_bias_var, self.use_bias,
use_fp8=True,
)
fc1_weight_fp8, fc1_weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
fc1_kernel_val,
fp8_meta,
tex.FP8FwdTensors.GEMM1_WEIGHT,
True,
fp8_dtype_fwd,
self.stream_id,
)
fc2_weight_fp8, fc2_weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
fc2_kernel_val,
fp8_meta,
tex.FP8FwdTensors.GEMM2_WEIGHT,
True,
fp8_dtype_fwd,
self.stream_id,
)
output_dtype = self._compute_dtype_object
fc1_out = fp8_matmul_wrapper(
ln_out,
fc1_weight_t_fp8,
fp8_meta,
"fc1_fwd",
fp8_dtype_fwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_FPROP,
self.stream_id,
use_bias=True,
bias=fc1_bias,
)
gelu_out = fp8_gelu_wrapper(
fc1_out,
fp8_meta,
tex.FP8FwdTensors.GEMM2_INPUT,
True,
fp8_dtype_fwd,
self.stream_id,
)
fc2_out = fp8_matmul_wrapper(
gelu_out,
fc2_weight_t_fp8,
fp8_meta,
"fc2_fwd",
fp8_dtype_fwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_FPROP,
self.stream_id,
use_bias=self.use_bias,
bias=fc2_bias,
)
def grad_fn(*upstream, variables=None):
self.pre_backward()
if self.use_bias:
(
fc2_bias_grad,
grad_fp8,
grad_t_fp8,
) = fp8_cast_transpose_bgrad_fused_wrapper(
upstream[0],
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
if not override_linear_precision.wgrad:
grad_fp8, grad_t_fp8 = fp8_cast_transpose_fused_wrapper(
upstream[0],
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
grad_fp8 = cast_to_fp8_wrapper(
upstream[0],
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT1,
False,
fp8_dtype_bwd,
self.stream_id,
)
fc2_dgrad = fp8_matmul_wrapper(
grad_fp8,
fc2_weight_fp8,
fp8_meta,
"fc2_bwd_input",
fp8_dtype_bwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_DGRAD,
self.stream_id,
)
if not override_linear_precision.wgrad:
gelu_out_t = tex.fp8_transpose(
gelu_out, fp8_dtype_fwd, self.stream_id
)
fc2_wgrad = fp8_matmul_wrapper(
gelu_out_t,
grad_t_fp8,
fp8_meta,
"fc2_bwd_weight",
fp8_dtype_fwd,
fp8_dtype_bwd,
output_dtype,
_2X_ACC_WGRAD,
self.stream_id,
)
(
fc1_bias_grad,
dgelu,
dgelu_t,
) = fp8_cast_transpose_bgrad_dgelu_fused_wrapper(
fc2_dgrad,
fc1_out,
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT2,
False,
fp8_dtype_bwd,
self.stream_id,
)
else:
gelu_out_c = cast_from_fp8_wrapper(
gelu_out,
fp8_meta,
tex.FP8FwdTensors.GEMM2_INPUT,
True,
fp8_dtype_fwd,
TE_DType[x.dtype],
self.stream_id,
)
fc2_wgrad = matmul_wrapper(
gelu_out_c,
upstream[0],
"bwd_weight",
output_dtype,
self.stream_id,
)
                    # Unlike the PyTorch implementation, fc1_out already
                    # includes the bias here, so we don't need to pass
                    # fc1_bias.
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(fc2_dgrad,
fc1_out)
dgelu = cast_to_fp8_wrapper(
dgelu_no_fp8,
fp8_meta,
tex.FP8BwdTensors.GRAD_OUTPUT2,
False,
fp8_dtype_bwd,
self.stream_id,
)
dgelu_t = None
fc1_dgrad = fp8_matmul_wrapper(
dgelu,
fc1_weight_fp8,
fp8_meta,
"fc1_bwd_input",
fp8_dtype_bwd,
fp8_dtype_fwd,
output_dtype,
_2X_ACC_DGRAD,
self.stream_id,
)
if not override_linear_precision.wgrad:
ln_out_t = tex.fp8_transpose(ln_out, fp8_dtype_fwd,
self.stream_id)
fc1_wgrad = fp8_matmul_wrapper(
ln_out_t,
dgelu_t,
fp8_meta,
"fc1_bwd_weight",
fp8_dtype_fwd,
fp8_dtype_bwd,
output_dtype,
_2X_ACC_WGRAD,
self.stream_id,
)
else:
ln_out_c = cast_from_fp8_wrapper(
ln_out,
fp8_meta,
tex.FP8FwdTensors.GEMM1_INPUT,
True,
fp8_dtype_fwd,
TE_DType[x.dtype],
self.stream_id,
)
fc1_wgrad = matmul_wrapper(
ln_out_c,
dgelu_no_fp8,
"bwd_weight",
output_dtype,
self.stream_id,
)
d_ln_out = fc1_dgrad
if self.return_layernorm_output:
assert len(upstream) == 2
d_ln_out = d_ln_out + upstream[1]
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, x, mu, rsigma, gamma_val, self.stream_id
)
grad_inputs = [dxmat]
grad_vars = []
for v in variables:
if v.name.endswith("gamma:0"):
grad_vars.append(dgamma)
elif v.name.endswith("fc1_kernel:0"):
grad_vars.append(fc1_wgrad)
elif v.name.endswith("fc1_bias:0"):
grad_vars.append(fc1_bias_grad)
elif v.name.endswith("fc2_kernel:0"):
grad_vars.append(fc2_wgrad)
elif v.name.endswith("fc2_bias:0") and self.use_bias:
grad_vars.append(fc2_bias_grad)
elif v.name.endswith("beta:0"):
grad_vars.append(dbeta)
return grad_inputs, grad_vars
if self.return_layernorm_output:
return (fc2_out, ln_out_return), grad_fn
return fc2_out, grad_fn
return fp8_layernorm_mlp_func(inp)
def call(
self,
inputs,
training=None,
):
"""
Apply layer normalization to the input followed by a feedforward network
(MLP Block).
Parameters
----------
inputs : tf.Tensor
Input tensor.
training : {True, False, None}, default = None
Whether this is in the training context.
"""
        # self.pre_forward needs to be called outside the following branch,
        # since it has the side effect of setting self.fp8 when the fp8
        # autocast context is detected.
training = self._get_training_value(training)
self.pre_forward(training, num_gemms=2)
inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))
if self.fp8:
outputs = self.fp8_layernorm_mlp(
inputmat,
self.gamma,
self.beta,
self.fc1_kernel,
self.fc1_bias,
self.fc2_kernel,
self.fc2_bias,
)
else:
outputs = self.non_fp8_layernorm_mlp(
inputmat,
self.gamma,
self.beta,
self.fc1_kernel,
self.fc1_bias,
self.fc2_kernel,
self.fc2_bias,
)
if self.return_layernorm_output:
outputmat, ln_outputmat = outputs
else:
outputmat = outputs
outputs = tf.reshape(
outputmat, shape=(-1, *inputs.shape[1:-1], outputmat.shape[-1])
)
if self.return_bias:
if self.return_layernorm_output:
ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
return (outputs, self.fc2_bias, ln_outputs)
return outputs, self.fc2_bias
if self.return_layernorm_output:
ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
return (outputs, ln_outputs)
return outputs
def get_config(self):
"""Returns the config of the layer."""
config = super().get_config()
        config.update(
            {
                "units": self.fc1_units,
                "ffn_units": self.fc2_units,
                "epsilon": self.epsilon,
                "gamma_initializer": initializers.serialize(
                    self.gamma_initializer),
                "beta_initializer": initializers.serialize(
                    self.beta_initializer),
                "return_layernorm_output": self.return_layernorm_output,
                "use_bias": self.use_bias,
                "return_bias": self.return_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel1_initializer),
                "ffn_kernel_initializer": initializers.serialize(
                    self.kernel2_initializer),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer),
            }
        )
        return config
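# Illustrative usage (hypothetical sizes; note that in this module `units` is
# the FC1 output size and `ffn_units` the FC2 output size):
#
#   mlp = LayerNormMLP(units=3072, ffn_units=768)
#   x = tf.random.normal((16, 768))
#   y = mlp(x, training=True)  # y.shape == (16, 768)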
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused scaled masked softmax functions"""
from typing import Callable
import os
import transformer_engine_tensorflow as tex
import tensorflow as tf
from .module import get_stream_id
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
_default_causal_mask = {}
def _get_default_causal_mask(sq: int) -> tf.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq not in _default_causal_mask:
# In TF, the mask specifies 1 to keep and 0 to mask. In "causal" mask
# mode, we compute the softmax of the lower triangular.
mask_operator = tf.linalg.LinearOperatorLowerTriangular(
tf.ones((sq, sq), dtype=tf.bool))
mask = mask_operator.to_dense()
_default_causal_mask[sq] = mask
return _default_causal_mask[sq]
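# For example, for sq = 3 the cached mask keeps the lower triangle
# (True = keep, False = mask out):
#
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]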
class FusedScaleMaskSoftmax(tf.keras.Model):
"""
fused operation: scaling + mask + softmax
Arguments:
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
attn_mask_type: str,
mask_func: Callable,
softmax_in_fp32: bool,
scale: float,
) -> None:
super().__init__()
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = bool(
int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
)
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
self.stream = get_stream_id()
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def __call__(self, inp: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
"""FusedScaleMaskSoftmax fprop"""
# [b, np, sq, sk]
assert len(inp.shape) == 4
self.input_in_fp16 = inp.dtype == tf.float16
self.input_in_bf16 = inp.dtype == tf.bfloat16
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
if self.is_kernel_available(*inp.shape):
return self.forward_fused_softmax(inp, mask)
return self.forward_tf_softmax(inp, mask)
def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and 16 < sk <= 4096  # sk must be within 16 ~ 4096
            and sq % 4 == 0  # sq must be a multiple of 4
            and attn_batches % 4 == 0  # np * b must be a multiple of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
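    # Illustrative check (hypothetical sizes): fp16 input with b = 2, np = 16
    # and sq = sk = 512 gives attn_batches = 32; get_batch_per_block(512)
    # returns 4 and 32 % 4 == 0, so the fused causal kernel is available.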
@tf.custom_gradient
def scaled_masked_softmax(self, x: tf.Tensor, mask: tf.Tensor,
scale: float):
"""Scaled masked softmax."""
y = tex.scaled_masked_softmax_forward(x, mask, scale, self.stream)
def grad_fn(upstream):
dx = tex.scaled_masked_softmax_backward(upstream, y, scale,
self.stream)
return dx, None, None
return y, grad_fn
@tf.custom_gradient
def scaled_softmax(self, x: tf.Tensor, scale: float):
"""Scaled softmax."""
y = tex.scaled_softmax_forward(x, scale, self.stream)
def grad_fn(upstream):
dx = tex.scaled_softmax_backward(upstream, y, scale, self.stream)
return dx, None
return y, grad_fn
@tf.custom_gradient
def scaled_upper_triang_masked_softmax(self, x: tf.Tensor, scale: float):
"""Scaled upper triangular masked softmax."""
y = tex.scaled_upper_triang_masked_softmax_forward(x, scale,
self.stream)
def grad_fn(upstream):
dx = tex.scaled_upper_triang_masked_softmax_backward(
upstream, y, scale, self.stream
)
return dx, None
return y, grad_fn
def forward_fused_softmax(
self,
inp: tf.Tensor,
mask: tf.Tensor,
) -> tf.Tensor:
"""Fused masked softmax kernel"""
sq, sk = inp.shape[2], inp.shape[3]
scale = self.scale if self.scale is not None else 1.0
        if self.attn_mask_type == "causal":
            assert sq == sk, "causal mask is only for self attention"
            # The fused kernel expects a 3D view (attn_batches, sq, sk).
            original_shape = inp.shape
            inp = tf.reshape(inp, (-1, sq, sk))
            probs = self.scaled_upper_triang_masked_softmax(inp, scale)
            # Restore the original [b, np, sq, sk] shape.
            return tf.reshape(probs, original_shape)
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
# The mask defined in TE kernels are different from TF. In TE, the
# mask specifies 1 to mask out and 0 to keep.
mask = tf.math.logical_not(mask)
ndims = len(mask.shape)
assert ndims <= 4, "mask ndims should be <= 4"
if len(mask.shape) < 4:
# Broadcasting the first dims of mask to match the input ndims.
broadcast_shape = [1] * (4 - ndims) + mask.shape[:]
mask = tf.reshape(mask, broadcast_shape)
return self.scaled_masked_softmax(inp, mask, scale)
return self.scaled_softmax(inp, scale)
def forward_tf_softmax(self, inp: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
"""Framework softmax"""
if self.input_in_float16 and self.softmax_in_fp32:
inp = tf.cast(inp, tf.float32)
if self.scale is not None:
inp = inp * self.scale
if self.attn_mask_type == "causal":
mask = _get_default_causal_mask(inp.shape[2])
mask_output = self.mask_func(inp, mask) if mask is not None else inp
probs = tf.nn.softmax(mask_output, axis=-1)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = tf.cast(probs, tf.half)
else:
probs = tf.cast(probs, tf.bfloat16)
return probs
@staticmethod
def get_batch_per_block(key_seq_len: int) -> int:
"""Softmax utility"""
pow2 = 1 << (key_seq_len - 1).bit_length()
warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = THREADS_PER_BLOCK // warp_size
batches_per_block = warps_per_block * batches_per_warp
return batches_per_block
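# Worked example for get_batch_per_block (hypothetical value): key_seq_len =
# 128 gives pow2 = 128, warp_size = 32, batches_per_warp = 2 and
# warps_per_block = 128 // 32 = 4, so each block handles 4 * 2 = 8 batches.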
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer."""
from contextlib import nullcontext
from typing import Callable, Optional, Tuple, Union
import os
from keras import backend, layers, initializers
from keras.mixed_precision import autocast_variable
import tensorflow as tf
from transformer_engine.tensorflow import (
LayerNorm,
LayerNormDense,
LayerNormMLP,
Dense,
)
from .softmax import FusedScaleMaskSoftmax
from .constants import (
AttnMaskTypes,
AttnTypes,
LayerTypes,
)
from .utils import (
divide,
attention_mask_func,
)
from .jit import (
get_bias_dropout_add,
bias_dropout_add_fused_train,
bias_dropout_add_fused_inference,
)
class CoreAttention(tf.keras.Model): # pylint: disable=too-few-public-methods
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
layer_number: Optional[int] = None,
apply_query_key_layer_scaling: bool = True,
attention_softmax_in_fp32: bool = False,
attn_mask_type: str = "causal",
) -> None:
super().__init__()
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
if layer_number is None:
self.apply_query_key_layer_scaling = False
else:
self.layer_number = max(1, layer_number)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.attn_mask_type = attn_mask_type
projection_size = kv_channels * num_attention_heads
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
# Per attention head and per partition values.
self.hidden_size_per_partition = divide(projection_size, 1)
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
self.attention_dropout_ctx = nullcontext
coeff = None
self.norm_factor = tf.math.sqrt(
float(self.hidden_size_per_attention_head))
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.attn_mask_type,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = layers.Dropout(attention_dropout)
def __call__(
self,
query_layer: tf.Tensor,
key_layer: tf.Tensor,
value_layer: tf.Tensor,
attention_mask: tf.Tensor,
) -> tf.Tensor:
"""core attention fprop"""
# [b, np, sq, sk]
output_size = (
query_layer.shape[1],
query_layer.shape[2],
query_layer.shape[0],
key_layer.shape[0],
)
# [sq, b, np, hn] -> [sq, b * np, hn]
new_q_shape = (output_size[2], output_size[0] * output_size[1], -1)
query_layer = tf.reshape(query_layer, new_q_shape)
# [sk, b, np, hn] -> [sk, b * np, hn]
new_k_shape = (output_size[3], output_size[0] * output_size[1], -1)
key_layer = tf.reshape(key_layer, new_k_shape)
norm_factor = self._maybe_cast_inputs(self.norm_factor)
# Raw attention scores. [b * np, sq, sk]
matmul_result = (
tf.matmul(
tf.transpose(query_layer, perm=(1, 0, 2)), # [b * np, sq, hn]
tf.transpose(key_layer, perm=(1, 2, 0)), # [b * np, hn, sk]
)
/ norm_factor
)
# change view to [b, np, sq, sk]
attention_scores = tf.reshape(matmul_result, output_size)
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
attention_probs = self.attention_dropout(attention_probs)
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
output_size = (
value_layer.shape[1],
value_layer.shape[2],
query_layer.shape[0],
value_layer.shape[3],
)
# change view [sk, b * np, hn]
new_v_shape = (value_layer.shape[0], output_size[0] * output_size[1],
-1)
value_layer = tf.reshape(value_layer, new_v_shape)
# change view [b * np, sq, sk]
new_attn_shape = (output_size[0] * output_size[1], output_size[2], -1)
attention_probs = tf.reshape(attention_probs, new_attn_shape)
# matmul: [b * np, sq, hn]
context_layer = tf.matmul(
attention_probs, # [b * np, sq, sk]
tf.transpose(value_layer, perm=(1, 0, 2)), # [b * np, sk, hn]
)
# change view [b, np, sq, hn]
context_layer = tf.reshape(context_layer, output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = tf.transpose(context_layer, perm=(2, 0, 1, 3))
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = (
*context_layer.shape[:-2],
self.hidden_size_per_partition,
)
context_layer = tf.reshape(context_layer, new_context_layer_shape)
return context_layer
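# Minimal usage sketch for CoreAttention (shapes and values are
# illustrative only; the layout is [seq, batch, heads, head_dim] as in
# the fprop comments above):
#
#   attn = CoreAttention(num_attention_heads=16, kv_channels=64,
#                        attention_dropout=0.1)
#   q = tf.random.normal((128, 2, 16, 64))   # [sq, b, np, hn]
#   k = tf.random.normal((128, 2, 16, 64))   # [sk, b, np, hn]
#   v = tf.random.normal((128, 2, 16, 64))   # [sk, b, np, hn]
#   ctx = attn(q, k, v, None)                # [sq, b, np * hn]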
class MultiHeadAttention(layers.Layer):
"""Parallel attention w/ QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
layernorm_epsilon: float = 1e-3,
init_method: Optional[Callable] = None,
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
apply_query_key_layer_scaling: bool = True,
attention_softmax_in_fp32: bool = False,
attn_mask_type: str = "causal",
return_layernorm_output: bool = False,
input_layernorm: bool = False,
attention_type: str = "self",
fuse_qkv_params: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
self.return_layernorm_output = return_layernorm_output
self.init_method = init_method
self.fuse_qkv_params = fuse_qkv_params
# We only support zero-initializer for bias weights.
self.bias_initializer = initializers.get("zeros")
assert (
attention_type in AttnTypes
), f"attention_type {attention_type} not supported"
self.hidden_size_per_attention_head = kv_channels
self.num_attention_heads_per_partition = divide(num_attention_heads, 1)
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormDense(
3 * hidden_size,
epsilon=layernorm_epsilon,
kernel_initializer=init_method,
use_bias=True,
return_bias=False,
return_layernorm_output=return_layernorm_output,
skip_weight_param_allocation=not fuse_qkv_params,
)
else:
self.qkv = Dense(
3 * hidden_size,
kernel_initializer=init_method,
use_bias=True,
return_bias=False,
skip_weight_param_allocation=not fuse_qkv_params,
)
else:
if self.input_layernorm:
self.layernorm_query = LayerNormDense(
hidden_size,
epsilon=layernorm_epsilon,
kernel_initializer=init_method,
use_bias=True,
return_bias=False,
return_layernorm_output=return_layernorm_output,
skip_weight_param_allocation=not fuse_qkv_params,
)
else:
self.query_layer = Dense(
hidden_size,
kernel_initializer=init_method,
use_bias=True,
return_bias=False,
skip_weight_param_allocation=not fuse_qkv_params,
)
self.key_value = Dense(
2 * hidden_size,
kernel_initializer=init_method,
use_bias=True,
return_bias=False,
skip_weight_param_allocation=not fuse_qkv_params,
)
# Core Self attention.
self.core_attention = CoreAttention(
num_attention_heads,
kv_channels,
attention_dropout,
layer_number=layer_number,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32,
attn_mask_type=attn_mask_type,
)
# Linear
self.proj = Dense(
hidden_size,
kernel_initializer=output_layer_init_method,
use_bias=False,
return_bias=True,
)
def build(self, input_shape):
"""One-time allocation of the variables."""
input_shape = tf.TensorShape(input_shape)
last_dim = tf.compat.dimension_value(input_shape[-1])
if last_dim is None:
raise ValueError(
"The last dimension of the inputs to a Dense layer should be "
f"defined. Found None. Full input shape received: {input_shape}"
)
if not self.fuse_qkv_params:
self.set_qkv_params(
last_dim,
3 * self.hidden_size,
use_bias=True,
)
def set_qkv_params(
self,
in_features,
out_features,
use_bias: bool = False,
) -> None:
"""Initialize separate Parameters for query, key, and value tensors."""
assert (
out_features % 3 == 0
), f"3 way QKV split with dimension {out_features} not possible."
qkv_dim = out_features // 3
if self.attention_type == "self":
self.qkv_weight = self.add_weight(
name="qkv_kernel",
shape=(in_features, out_features),
initializer=self.init_method,
trainable=True,
)
self.qkv_bias = None
if use_bias:
self.qkv_bias = self.add_weight(
name="qkv_bias",
shape=(out_features,),
initializer=self.bias_initializer,
trainable=True,
)
else:
self.q_weight = self.add_weight(
name="q_kernel",
shape=(in_features, qkv_dim),
initializer=self.init_method,
trainable=True,
)
self.kv_weight = self.add_weight(
name="kv_kernel",
shape=(in_features, 2 * qkv_dim),
initializer=self.init_method,
trainable=True,
)
self.q_bias = None
self.kv_bias = None
if use_bias:
self.q_bias = self.add_weight(
name="q_bias",
shape=(qkv_dim,),
initializer=self.bias_initializer,
trainable=True,
)
self.kv_bias = self.add_weight(
name="kv_bias",
shape=(2 * qkv_dim,),
initializer=self.bias_initializer,
trainable=True,
)
def _get_training_value(self, training=None):
if training is None:
training = backend.learning_phase()
if isinstance(training, int):
training = bool(training)
if not self.trainable:
# When the layer is not trainable, it overrides the value passed
# from model.
training = False
return training
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
encoder_output: Optional[tf.Tensor] = None,
training: bool = None,
) -> Tuple[Union[tf.Tensor, None], ...]:
"""MultiHeadAttention FWD"""
training = self._get_training_value(training)
# hidden_states: [sq, b, h]
if attention_mask is not None:
assert (
attention_mask.dtype == tf.bool
), "Attention mask must be a boolean tensor"
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == "self":
qkv_weight = self.qkv_weight if not self.fuse_qkv_params else None
qkv_bias = self.qkv_bias if not self.fuse_qkv_params else None
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
kernel=qkv_weight,
bias=qkv_bias,
training=training,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
else:
mixed_x_layer = layernorm_qkv_outputs
else:
mixed_x_layer = self.qkv(
hidden_states,
kernel=qkv_weight,
bias=qkv_bias,
training=training,
)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = (
*mixed_x_layer.shape[:-1],
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = tf.reshape(mixed_x_layer, new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = tf.split(
mixed_x_layer, num_or_size_splits=3, axis=-1
)
else:
kv_weight = self.kv_weight if not self.fuse_qkv_params else None
kv_bias = self.kv_bias if not self.fuse_qkv_params else None
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
encoder_output,
kernel=kv_weight,
bias=kv_bias,
training=training,
)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = (
*mixed_kv_layer.shape[:-1],
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = tf.reshape(mixed_kv_layer, new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
key_layer, value_layer = tf.split(
mixed_kv_layer, num_or_size_splits=2, axis=-1
)
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
hidden_states,
kernel=self.q_weight,
bias=self.q_bias,
training=training,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
else:
query_layer = layernorm_query_outputs
else:
query_layer = self.query_layer(
hidden_states,
kernel=self.q_weight,
bias=self.q_bias,
training=training,
)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = (
*query_layer.shape[:-1],
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = tf.reshape(query_layer, new_tensor_shape)
# ==================================
# core attention computation
# ==================================
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
# =================
# Output. [sq, b, h]
# =================
attention_output, attention_bias = self.proj(
context_layer,
training=training,
)
if self.input_layernorm and self.return_layernorm_output:
return attention_output, attention_bias, layernorm_output
return attention_output, attention_bias
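# Minimal usage sketch for self-attention (hyperparameters are
# illustrative only):
#
#   mha = MultiHeadAttention(hidden_size=1024, num_attention_heads=16,
#                            kv_channels=64, attention_dropout=0.1)
#   x = tf.random.normal((128, 2, 1024))     # [sq, b, h]
#   out, bias = mha(x, attention_mask=None)  # out: [sq, b, h]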
class DropPath(tf.keras.Model): # pylint: disable=too-few-public-methods
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
"""
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def __call__(self, hidden_state: tf.Tensor, training: bool) -> tf.Tensor:
"""DropPath FWD"""
if self.drop_prob == 0.0 or not training:
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (len(hidden_state.shape) - 1)
# TODO(kaixih): We set the seed mainly for debugging purposes. We
# should allow users to turn it off.
random_tensor = tf.random.stateless_uniform(shape, seed=[1, 0])
random_mask = tf.cast(random_tensor <= keep_prob,
dtype=hidden_state.dtype)
output = (hidden_state / keep_prob) * random_mask
return output
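# Example (illustrative): with drop_prob = 0.2, a kept sample is scaled
# by 1 / 0.8 so the expected activation is unchanged:
#
#   drop_path = DropPath(drop_prob=0.2)
#   x = tf.ones((4, 16, 1024))         # one mask value per sample
#   y = drop_path(x, training=True)    # rows are all 0.0 or all 1.25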
class TransformerLayer(tf.keras.Model): # pylint: disable=too-few-public-methods
"""
TransformerLayer is made up of an attention block and a feedforward network
(MLP). This standard layer is based on the paper
"Attention Is All You Need".
Parameters
----------
hidden_size : int
size of each input sample.
ffn_hidden_size : int
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization for numerical
stability.
hidden_dropout: float, default = 0.1
dropout probability for the dropout op after FC2 layer.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
init_method : Callable, default = `None`
used for initializing the QKV and FC1 weights in the following way:
`init_method(weight)`. When set to `None`, defaults to
`tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
output_layer_init_method : Callable, default = `None`
used for initializing the PROJ and FC2 weights in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to
`tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
apply_residual_connection_post_layernorm : bool, default = `False`
if set to `True`, residual connections are taken from the output of layer
norm (default is taken from input of layer norm)
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules
are concatenated to form a transformer block.
apply_query_key_layer_scaling: bool, default = `True`
apply query-key layer scaling during BMM1 by a factor of `layer_number`
output_layernorm: bool, default = `False`
if set to `True`, layer normalization is applied on the output side, after
the final dropout-add. default behavior is to apply layer normalization on
the input side, before the QKV transformation.
attention_softmax_in_fp32: bool, default = `False`
if set to `True`, softmax is executed in tf.float32 dtype (single
precision)
layer_type: {'encoder', 'decoder'}, default = `encoder`
if set to `decoder`, an additional cross-attn block is added after
self-attn. This can be used for structures like `T5` Transformer in
conjunction with the `encoder` option.
kv_channels: int, default = `None`
number of key-value channels. defaults to
`hidden_size / num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
Optimization parameters
-----------------------
drop_path_rate: float, default = 0.0
when > 0.0, applies stochastic depth per sample in the main path of the
residual block.
fuse_qkv_params: bool, default = `False`
if set to `True`, the `TransformerLayer` module exposes a single fused
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits.
"""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1,
attention_dropout: float = 0.1,
init_method: Optional[Callable] = None,
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
apply_query_key_layer_scaling: bool = True,
attention_softmax_in_fp32: bool = False,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
layer_type: str = "encoder",
drop_path_rate: float = 0.0,
fuse_qkv_params: bool = False,
) -> None:
super().__init__()
bias_dropout_fusion = \
bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
self.layer_number = layer_number
self.output_layernorm = output_layernorm
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm
)
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert layer_type in LayerTypes, \
f"layer_type {layer_type} not supported"
self.kv_channels = (
kv_channels if kv_channels else (hidden_size // num_attention_heads)
)
if init_method is None:
init_method = initializers.RandomNormal(mean=0.0, stddev=0.023)
if output_layer_init_method is None:
output_layer_init_method = initializers.RandomNormal(mean=0.0,
stddev=0.023)
attention_args = (
hidden_size,
num_attention_heads,
self.kv_channels,
attention_dropout,
layernorm_epsilon,
init_method,
output_layer_init_method,
)
common_attention_kwargs = {
"layer_number": layer_number,
"apply_query_key_layer_scaling": apply_query_key_layer_scaling,
"attention_softmax_in_fp32": attention_softmax_in_fp32,
"return_layernorm_output": apply_residual_connection_post_layernorm,
"fuse_qkv_params": fuse_qkv_params,
}
self.self_attention = MultiHeadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type=self_attn_mask_type,
input_layernorm=not output_layernorm,
attention_type="self",
)
if layer_type == "decoder":
self.inter_attention = MultiHeadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type="padding",
input_layernorm=True,
attention_type="cross",
)
# LayerNorm -> gelu(Linear + Bias) -> Linear
self.layernorm_mlp = LayerNormMLP(
hidden_size,
ffn_hidden_size,
epsilon=layernorm_epsilon,
kernel_initializer=init_method,
ffn_kernel_initializer=output_layer_init_method,
use_bias=False,
return_bias=True,
return_layernorm_output=apply_residual_connection_post_layernorm,
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion
self.drop_path = (DropPath(drop_path_rate) if drop_path_rate > 0.0 else
None)
if self.output_layernorm:
self.layernorm = LayerNorm(
epsilon=layernorm_epsilon,
)
def _get_training_value(self, training=None):
if training is None:
training = backend.learning_phase()
if isinstance(training, int):
training = bool(training)
if not self.trainable:
# When the layer is not trainable, it overrides the value passed
# from model.
training = False
return training
def __call__(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
encoder_output: Optional[tf.Tensor] = None,
enc_dec_attn_mask: Optional[tf.Tensor] = None,
training: bool = None,
) -> tf.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
Parameters
----------
hidden_states : tf.Tensor
Input tensor.
attention_mask : tf.Tensor
Boolean tensor used to mask out self-attention softmax input.
encoder_output : tf.Tensor
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
enc_dec_attn_mask : tf.Tensor
Boolean tensor used to mask out inter-attention softmax input if using
`layer_type="decoder"`.
"""
if attention_mask is not None:
assert (
attention_mask.dtype == tf.bool
), "Attention mask must be a boolean tensor"
# Theoretically, the input dtype can be handled by the autocast during
# the layer call. However, we may use the input (hidden_states) in the
# residual connection before the layer is called. So, we convert it
# ahead of time. As for the other input (encoder_output), we can leave
# the conversion to the inter_attention layer, since it won't be used in
# the residual connection.
hidden_states = self._maybe_cast_inputs(hidden_states)
# Self attention.
self_attention_outputs = self.self_attention(
hidden_states,
attention_mask,
training=training,
)
if (self.apply_residual_connection_post_layernorm and
not self.output_layernorm):
attention_output, attention_bias, residual = self_attention_outputs
else:
attention_output, attention_bias = self_attention_outputs
residual = hidden_states
# Set BDA func.
if self.bias_dropout_fusion:
if training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(training)
# Bias dropout add.
# The autocast scope is used to enforce the correct dtype for the bias.
with autocast_variable.enable_auto_cast_variables(
self._compute_dtype_object):
if self.drop_path is None:
bda_output = bias_dropout_add_func(
attention_output,
attention_bias,
residual,
self.hidden_dropout,
)
else:
# TODO(kaixih): We use stateless_dropout and specify the seed
# mainly for debugging purposes. We should allow a random seed.
out = (
tf.nn.experimental.stateless_dropout(
attention_output + attention_bias,
rate=self.hidden_dropout,
seed=[1, 0],
)
if training
else attention_output + attention_bias
)
bda_output = residual + self.drop_path(out, training)
# Cross attention.
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
bda_output,
enc_dec_attn_mask,
encoder_output=encoder_output,
training=training,
)
if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = \
inter_attention_outputs
else:
attention_output, attention_bias = inter_attention_outputs
residual = bda_output
# The autocast scope is used to enforce the correct dtype for the
# bias.
with autocast_variable.enable_auto_cast_variables(
self._compute_dtype_object
):
bda_output = bias_dropout_add_func(
attention_output,
attention_bias,
residual,
self.hidden_dropout,
)
# MLP.
mlp_outputs = self.layernorm_mlp(
bda_output,
training=training,
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
else:
mlp_output, mlp_bias = mlp_outputs
residual = bda_output
# Bias dropout add.
# The autocast scope is used to enforce the correct dtype for the bias.
with autocast_variable.enable_auto_cast_variables(
self._compute_dtype_object):
if self.drop_path is None:
output = bias_dropout_add_func(
mlp_output,
mlp_bias,
residual,
self.hidden_dropout,
)
else:
# TODO(kaixih): We use stateless_dropout and specify the seed
# mainly for debugging purposes. We should allow a random seed.
output = (
tf.nn.experimental.stateless_dropout(
mlp_output + mlp_bias,
rate=self.hidden_dropout,
seed=[1, 0],
)
if training
else mlp_output + mlp_bias
)
output = residual + self.drop_path(output, training)
# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)
# output: [sq, b, h]
return output
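# Minimal usage sketch for TransformerLayer (hyperparameters are
# illustrative only; the default `causal` mask is applied when
# attention_mask is None):
#
#   layer = TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
#                            num_attention_heads=16)
#   x = tf.random.normal((128, 2, 1024))              # [sq, b, h]
#   y = layer(x, attention_mask=None, training=True)  # same shape as x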
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""
import tensorflow as tf
def attention_mask_func(
attention_scores: tf.Tensor, attention_mask: tf.Tensor
) -> tf.Tensor:
"""Get attention mask"""
return tf.where(attention_mask, attention_scores, -10000.0)
def ensure_divisibility(numerator: int, denominator: int) -> None:
"""Ensure that numerator is divisible by the denominator."""
assert (
numerator % denominator == 0
), f"{numerator} is not divisible by {denominator}"
def divide(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
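# Example (illustrative): positions where the mask is True keep their
# scores; masked positions get -10000.0 and vanish after softmax:
#
#   scores = tf.zeros((1, 1, 2, 2))
#   mask = tf.constant([[[[True, False], [True, True]]]])
#   attention_mask_func(scores, mask)
#   # -> [[[[0.0, -10000.0], [0.0, 0.0]]]]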