Unverified commit 16f3f897 authored by Phuong Nguyen, committed by GitHub

[JAX] Splitting `csrc/modules.cpp` by category (#883)



* categorized `csrc/modules.cpp`
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* adapted the build tool
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c6ce2b8c
@@ -6,8 +6,9 @@
 from pathlib import Path
 import setuptools
+from glob import glob
-from .utils import cuda_path
+from .utils import cuda_path, all_files_in_dir
 from typing import List
@@ -19,11 +20,10 @@ def setup_jax_extension(
     """Setup PyBind11 extension for JAX support"""
     # Source files
     csrc_source_files = Path(csrc_source_files)
+    extensions_dir = csrc_source_files / "extensions"
     sources = [
-        csrc_source_files / "extensions.cpp",
-        csrc_source_files / "modules.cpp",
         csrc_source_files / "utils.cu",
-    ]
+    ] + all_files_in_dir(extensions_dir)
     # Header files
     cuda_home, _ = cuda_path()
@@ -11,6 +11,9 @@
 #include <cstddef>
 #include <cstdint>
 #include <vector>
+#include <stdexcept>
+#include <string>
+#include <iostream>
 #include <cuda_runtime_api.h>
 #include <pybind11/pybind11.h>
@@ -19,41 +22,54 @@
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>
 #include <transformer_engine/activation.h>
+#include "common/common.h"
 #include "common/util/logging.h"
+#include "utils.h"
+#include <cublasLt.h>
+#include <cublas_v2.h>
+#include <cuda_runtime_api.h>
+#include <cudnn.h>
 namespace transformer_engine {
 namespace jax {
 constexpr int kMaxNumDim = 8;
-size_t get_activation_len(NVTE_Activation_Type activation_enum);
-// TODO: Rename Shape to ???
 struct Shape {
   int num_dim;
   size_t dims[kMaxNumDim];
-  void from_vector(const std::vector<size_t> &shape) {
-    num_dim = shape.size();
-    assert(num_dim <= kMaxNumDim);
-    std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
-  }
+  void from_vector(const std::vector<size_t> &shape);
-  std::vector<size_t> to_vector() const {
-    assert(num_dim <= kMaxNumDim);
-    std::vector<size_t> shape(num_dim);
-    std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
-    return shape;
-  }
+  std::vector<size_t> to_vector() const;
 };
-enum class NVTE_Activation_Enum {
-  GELU,
-  GEGLU,
-  SILU,
-  SWIGLU,
-};
-size_t get_activation_len(NVTE_Activation_Enum act_enum);
+// Phuong: These 3 functions need to stay in the header file for compilation purpose
+// 1.
+inline bool use_fp8(DType type) {
+  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
+}
+// 2.
+template <typename T>
+pybind11::bytes PackOpaque(const T &descriptor) {
+  auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
+  return pybind11::bytes(str);
+}
+// 3.
+template <typename T>
+const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
+  if (opaque_len != sizeof(T)) {
+    throw std::runtime_error("Invalid opaque object size");
+  }
+  return reinterpret_cast<const T *>(opaque);
+}
+std::vector<size_t> MakeShapeVector(NVTEShape shape);
+// Packing
 struct CustomCallCommonDescriptor {
   Shape shape;
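Every custom call in this header moves its compile-time parameters through XLA's untyped opaque buffer via the `PackOpaque`/`UnpackOpaque` pair kept inline above. A minimal round-trip sketch, assuming a hypothetical `ExampleDescriptor` that is not part of this diff:

// Sketch only: ExampleDescriptor is a hypothetical POD type for illustration.
struct ExampleDescriptor {
  size_t rows;
  size_t cols;
};

// Binding side: serialize the descriptor into the bytes object that the JAX
// primitive attaches to its custom call.
pybind11::bytes MakeExampleOpaque() {
  return PackOpaque(ExampleDescriptor{128, 256});
}

// Kernel side: recover it from (opaque, opaque_len); UnpackOpaque throws if
// the length does not match sizeof(ExampleDescriptor).
void ExampleKernel(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<ExampleDescriptor>(opaque, opaque_len);
  (void)desc.rows;
  (void)desc.cols;
}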
@@ -144,17 +160,22 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
     bool is_training);
-NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
-                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                            NVTE_Mask_Type mask_type, float dropout_probability,
-                                            size_t q_num_heads, size_t kv_num_heads,
-                                            size_t q_max_seqlen, size_t kv_max_seqlen,
-                                            size_t head_dim);
+// Transpose
 void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
 void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
+                                                    DType in_dtype, DType out_dtype);
+void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
+                        size_t opaque_len);
+// Activation
+size_t get_activation_len(NVTE_Activation_Type activation_enum);
 void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
 void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -167,15 +188,11 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
 void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                               size_t opaque_len);
-pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
-                                                    DType in_dtype, DType out_dtype);
-void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                        size_t opaque_len);
 void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                               size_t opaque_len);
+// Normalization
 pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype, DType out_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
@@ -199,10 +216,14 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
 void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+// Quantization
 void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
 void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+// Softmax
 void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);
@@ -221,6 +242,15 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
 void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                             std::size_t opaque_len);
+// Attention
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                            NVTE_Mask_Type mask_type, float dropout_probability,
+                                            size_t q_num_heads, size_t kv_num_heads,
+                                            size_t q_max_seqlen, size_t kv_max_seqlen,
+                                            size_t head_dim);
 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
namespace jax {
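// The activation length is the number of stacked inputs an activation reads
// per output element: the gated variants (GEGLU, SWIGLU, REGLU, QGEGLU,
// SREGLU) combine a linear half and a gating half, so callers below size the
// forward input as {m, n * act_len}.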
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
  switch (activation_enum) {
    case NVTE_Activation_Type::GELU: return 1;
    case NVTE_Activation_Type::GEGLU: return 2;
    case NVTE_Activation_Type::SILU: return 1;
    case NVTE_Activation_Type::SWIGLU: return 2;
    case NVTE_Activation_Type::RELU: return 1;
    case NVTE_Activation_Type::REGLU: return 2;
    case NVTE_Activation_Type::QGELU: return 1;
    case NVTE_Activation_Type::QGEGLU: return 2;
    case NVTE_Activation_Type::SRELU: return 1;
    case NVTE_Activation_Type::SREGLU: return 2;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return -1;
}
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
               NVTE_Activation_Type act_enum) {
  auto act_len = get_activation_len(act_enum);
  auto input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                     scale, scale_inverse);
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *output = buffers[1];
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
            nullptr, nullptr, output, act_enum);
}
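// FP8 path: buffers[] carries input, amax, scale, scale_inv, output, and
// amax_out (bound to the same buffer as amax); the scaling pointers are
// dropped below whenever the output dtype is not an FP8 type.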
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  float *amax_out = reinterpret_cast<float *>(buffers[5]);
  assert(amax == amax_out);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
            scale_inv, amax_out, output, act_enum);
}
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  auto *output = buffers[2];
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto act_len = get_activation_len(act_enum);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
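// Workspace-size query: running the fused kernel with empty tensor wrappers
// and a null stream does no work, but fills dummy_workspace with the shape
// and dtype that the real invocation will require.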
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};
  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
  auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
  auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
  auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
  TensorWrapper dummy_workspace;
  // For now, all dbias_dact(-s) have the same workspace size
  nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(),
                                  dbias_tensor.data(), dummy_workspace.data(), nullptr);
  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  auto *dbias = buffers[7];
  float *amax_out = reinterpret_cast<float *>(buffers[8]);
  void *workspace_ptr = buffers[9];
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
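// Gated backward: the incoming gradient keeps the descriptor shape {m, n},
// while the saved forward input and both outputs span the doubled hidden
// dimension, {m, 2 * n} and {2 * n, m} for the transposed copy.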
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto input_shape = desc.shape.to_vector();
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  switch (act_enum) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
} // namespace jax
} // namespace transformer_engine
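The functions above remain the raw entry points that XLA invokes; how they are exposed to Python is outside this diff. As a rough sketch of the usual pybind11 pattern (the `EncapsulateFunction` helper and module name here are assumptions for illustration, not code from this commit):

#include <pybind11/pybind11.h>

// Assumed helper: wrap a function pointer in a PyCapsule so JAX can register
// it as an XLA custom-call target.
template <typename T>
pybind11::capsule EncapsulateFunction(T *fn) {
  return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}

PYBIND11_MODULE(te_jax_sketch, m) {
  m.def("registrations", []() {
    pybind11::dict d;
    d["te_act_lu"] = EncapsulateFunction(transformer_engine::jax::ActLu);
    d["te_dact_lu"] = EncapsulateFunction(transformer_engine::jax::DActLu);
    return d;
  });
}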
@@ -4,1126 +4,12 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "modules.h"
+#include "jax/csrc/extensions.h"
-#include <cublasLt.h>
-#include <cublas_v2.h>
-#include <cuda_runtime_api.h>
-#include <cudnn.h>
-#include <stdexcept>
-#include <string>
-#include <vector>
-#include <iostream>
-#include "common/common.h"
-#include "common/util/logging.h"
-#include "transformer_engine/activation.h"
-#include "transformer_engine/cast.h"
 #include "transformer_engine/fused_attn.h"
-#include "transformer_engine/layer_norm.h"
-#include "transformer_engine/rmsnorm.h"
-#include "transformer_engine/softmax.h"
-#include "transformer_engine/transformer_engine.h"
-#include "transformer_engine/transpose.h"
-#include "utils.h"
 namespace transformer_engine {
 namespace jax {
-inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
-std::vector<size_t> MakeShapeVector(NVTEShape shape) {
-  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
-}
-size_t get_activation_len(NVTE_Activation_Type activation_enum) {
-  switch (activation_enum) {
-    case NVTE_Activation_Type::GELU: return 1;
-    case NVTE_Activation_Type::GEGLU: return 2;
-    case NVTE_Activation_Type::SILU: return 1;
-    case NVTE_Activation_Type::SWIGLU: return 2;
-    case NVTE_Activation_Type::RELU: return 1;
-    case NVTE_Activation_Type::REGLU: return 2;
-    case NVTE_Activation_Type::QGELU: return 1;
-    case NVTE_Activation_Type::QGEGLU: return 2;
-    case NVTE_Activation_Type::SRELU: return 1;
-    case NVTE_Activation_Type::SREGLU: return 2;
-    default:
-      NVTE_ERROR("Unsupported ActivationEnum");
-      break;
-      return -1;
-  }
-}
-template <typename T>
-pybind11::bytes PackOpaque(const T &descriptor) {
-  auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
-  return pybind11::bytes(str);
-}
-template <typename T>
-const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
-  if (opaque_len != sizeof(T)) {
-    throw std::runtime_error("Invalid opaque object size");
-  }
-  return reinterpret_cast<const T *>(opaque);
-}
-pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
-                                               DType out_dtype, size_t act_enum) {
-  CustomCallCommonDescriptor desc;
-  desc.shape.from_vector(shape);
-  desc.in_dtype = in_dtype;
-  desc.out_dtype = out_dtype;
-  desc.act_enum = act_enum;
-  return PackOpaque(desc);
-}
-pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
-                                                 const std::vector<size_t> &wkshape, DType in_dtype,
-                                                 DType out_dtype, DType wk_dtype,
-                                                 size_t act_enum) {
-  CustomCallCommonWkDescriptor desc;
-  desc.shape.from_vector(shape);
-  desc.wkshape.from_vector(wkshape);
-  desc.in_dtype = in_dtype;
-  desc.out_dtype = out_dtype;
-  desc.wk_dtype = wk_dtype;
-  desc.act_enum = act_enum;
-  return PackOpaque(desc);
-}
-pybind11::bytes PackCustomCallNormDescriptor(
-    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
-    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
-    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
-    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
-  CustomCallNormDescriptor desc;
-  desc.batch_size = batch_size;
-  desc.hidden_size = hidden_size;
-  desc.wkspace_size = wkspace_size;
-  desc.barrier_size = barrier_size;
-  desc.dgamma_part_shape.from_vector(dgamma_part_shape);
-  desc.dbeta_part_shape.from_vector(dbeta_part_shape);
-  desc.x_dtype = x_dtype;
-  desc.w_dtype = w_dtype;
-  desc.wkspace_dtype = wkspace_dtype;
-  desc.barrier_dtype = barrier_dtype;
-  desc.dgamma_part_dtype = dgamma_part_dtype;
-  desc.dbeta_part_dtype = dbeta_part_dtype;
-  desc.zero_centered_gamma = zero_centered_gamma;
-  desc.eps = eps;
-  desc.sm_margin = sm_margin;
-  return PackOpaque(desc);
-}
-pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
-                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
-                                                DType dtype, float scale_factor) {
-  return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
-                                      dtype, scale_factor});
-}
-pybind11::bytes PackCustomCallFusedAttnDescriptor(
-    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
-    bool is_training) {
-  return PackOpaque(CustomCallFusedAttnDescriptor{
-      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
-      bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
-      mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
-}
-void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
-                   void *output) {
-  auto input_shape = std::vector<size_t>{rows, cols};
-  auto output_shape = std::vector<size_t>{cols, rows};
-  auto input_tensor = TensorWrapper(input, input_shape, dtype);
-  auto transposed_tensor = TensorWrapper(output, output_shape, dtype);
-  nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
-}
-void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  void *input = buffers[0];
-  void *output = buffers[1];
-  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
-  auto rows = desc.shape.dims[0];
-  auto cols = desc.shape.dims[1];
-  assert(desc.in_dtype == desc.out_dtype);
-  auto dtype = desc.out_dtype;
-  TransposeImpl(input, rows, cols, dtype, stream, output);
-}
-void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  float *amax = reinterpret_cast<float *>(buffers[1]);
-  float *scale = reinterpret_cast<float *>(buffers[2]);
-  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
-  auto *input_cast = buffers[4];
-  auto *input_cast_trans = buffers[5];
-  float *amax_out = reinterpret_cast<float *>(buffers[6]);
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
-  if (!use_fp8(desc.out_dtype)) {
-    scale = nullptr;
-    scale_inv = nullptr;
-    amax_out = nullptr;
-  }
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto input_shape = std::vector<size_t>{m, n};
-  auto input_trans_shape = std::vector<size_t>{n, m};
-  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
-  auto input_cast_tensor =
-      TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape,
-                                               desc.out_dtype, amax_out, scale, scale_inv);
-  nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(),
-                      input_cast_trans_tensor.data(), stream);
-}
-void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
-               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
-               NVTE_Activation_Type act_enum) {
-  auto act_len = get_activation_len(act_enum);
-  auto input_shape = std::vector<size_t>{m, n * act_len};
-  auto output_shape = std::vector<size_t>{m, n};
-  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
-  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
-                                     scale, scale_inverse);
-  switch (act_enum) {
-    case NVTE_Activation_Type::GELU:
-      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::GEGLU:
-      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SILU:
-      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SWIGLU:
-      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::RELU:
-      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::REGLU:
-      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::QGELU:
-      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::QGEGLU:
-      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SRELU:
-      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SREGLU:
-      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
-      break;
-    default:
-      NVTE_ERROR("Unsupported ActivationEnum");
-      break;
-  }
-}
-void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *output = buffers[1];
-  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
-  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
-            nullptr, nullptr, output, act_enum);
-}
-void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  float *amax = reinterpret_cast<float *>(buffers[1]);
-  float *scale = reinterpret_cast<float *>(buffers[2]);
-  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
-  auto *output = buffers[4];
-  float *amax_out = reinterpret_cast<float *>(buffers[5]);
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
-  if (!use_fp8(desc.out_dtype)) {
-    scale = nullptr;
-    scale_inv = nullptr;
-    amax_out = nullptr;
-  }
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
-  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
-            scale_inv, amax_out, output, act_enum);
-}
-void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *act_input = buffers[1];
-  auto *output = buffers[2];
-  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
-  auto act_len = get_activation_len(act_enum);
-  auto input_shape = std::vector<size_t>{m, n};
-  auto act_input_shape = std::vector<size_t>{m, n * act_len};
-  auto output_shape = std::vector<size_t>{m, n * act_len};
-  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
-  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
-  auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
-  switch (act_enum) {
-    case NVTE_Activation_Type::GELU:
-      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::GEGLU:
-      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SILU:
-      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SWIGLU:
-      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::RELU:
-      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::REGLU:
-      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::QGELU:
-      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::QGEGLU:
-      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SRELU:
-      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SREGLU:
-      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
-      break;
-    default:
-      NVTE_ERROR("Unsupported ActivationEnum");
-      break;
-  }
-}
-pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
-                                                        DType in_dtype, DType out_dtype) {
-  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
-  auto dbias_shape = std::vector<size_t>{hidden_size};
-  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
-  auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
-  auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
-  auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
-  TensorWrapper dummy_workspace;
-  // For now, all dbias_dact(-s) have the same workspace size
-  nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(),
-                                  dbias_tensor.data(), dummy_workspace.data(), nullptr);
-  auto work_shape = MakeShapeVector(dummy_workspace.shape());
-  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
-}
-void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                              size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *act_input = buffers[1];
-  float *amax = reinterpret_cast<float *>(buffers[2]);
-  float *scale = reinterpret_cast<float *>(buffers[3]);
-  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
-  auto *output = buffers[5];
-  auto *output_trans = buffers[6];
-  auto *dbias = buffers[7];
-  float *amax_out = reinterpret_cast<float *>(buffers[8]);
-  void *workspace_ptr = buffers[9];
-  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  if (!use_fp8(desc.out_dtype)) {
-    scale = nullptr;
-    scale_inv = nullptr;
-    amax_out = nullptr;
-  }
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
-  auto input_shape = std::vector<size_t>{m, n};
-  auto act_input_shape = std::vector<size_t>{m, n};
-  auto output_shape = std::vector<size_t>{m, n};
-  auto output_trans_shape = std::vector<size_t>{n, m};
-  auto dbias_shape = std::vector<size_t>{n};
-  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
-  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
-  auto output_tensor =
-      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
-  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
-  switch (act_enum) {
-    case NVTE_Activation_Type::GELU:
-      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace.data(), stream);
-      break;
-    case NVTE_Activation_Type::SILU:
-      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace.data(), stream);
-      break;
-    case NVTE_Activation_Type::RELU:
-      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace.data(), stream);
-      break;
-    case NVTE_Activation_Type::QGELU:
-      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace.data(), stream);
-      break;
-    case NVTE_Activation_Type::SRELU:
-      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace.data(), stream);
-      break;
-    default:
-      NVTE_ERROR("Unsupported ActivationEnum");
-      break;
-  }
-}
-void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                              size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *act_input = buffers[1];
-  float *amax = reinterpret_cast<float *>(buffers[2]);
-  float *scale = reinterpret_cast<float *>(buffers[3]);
-  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
-  auto *output = buffers[5];
-  auto *output_trans = buffers[6];
-  float *amax_out = reinterpret_cast<float *>(buffers[7]);
-  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  if (!use_fp8(desc.out_dtype)) {
-    scale = nullptr;
-    scale_inv = nullptr;
-    amax_out = nullptr;
-  }
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
-  auto input_shape = desc.shape.to_vector();
-  auto act_input_shape = std::vector<size_t>{m, n * 2};
-  auto output_shape = std::vector<size_t>{m, n * 2};
-  auto output_trans_shape = std::vector<size_t>{n * 2, m};
-  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
-  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
-  auto output_tensor =
-      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  switch (act_enum) {
-    case NVTE_Activation_Type::GEGLU:
-      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                 output_tensor.data(), output_trans_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SWIGLU:
-      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::REGLU:
-      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                 output_tensor.data(), output_trans_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::QGEGLU:
-      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
-      break;
-    case NVTE_Activation_Type::SREGLU:
-      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
-      break;
-    default:
-      NVTE_ERROR("Unsupported ActivationEnum");
-      break;
-  }
-}
-pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
-                                                    DType in_dtype, DType out_dtype) {
-  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
-  auto dbias_shape = std::vector<size_t>{hidden_size};
-  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
-  auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
-  auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
-  TensorWrapper dummy_workspace;
-  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
-                            output_trans_tensor.data(), dbias_tensor.data(),
-                            dummy_workspace.data(), nullptr);
-  auto work_shape = MakeShapeVector(dummy_workspace.shape());
-  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
-}
-void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
-                        size_t opaque_len) {
-  auto *input = buffers[0];
-  float *amax = reinterpret_cast<float *>(buffers[1]);
-  float *scale = reinterpret_cast<float *>(buffers[2]);
-  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
-  auto *output = buffers[4];
-  auto *output_trans = buffers[5];
-  auto *dbias = buffers[6];
-  float *amax_out = reinterpret_cast<float *>(buffers[7]);
-  void *workspace_ptr = buffers[8];
-  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  if (!use_fp8(desc.out_dtype)) {
-    scale = nullptr;
-    scale_inv = nullptr;
-    amax_out = nullptr;
-  }
-  auto m = desc.shape.dims[0];
-  auto n = desc.shape.dims[1];
-  auto input_shape = std::vector<size_t>{m, n};
-  auto output_shape = std::vector<size_t>{m, n};
-  auto output_trans_shape = std::vector<size_t>{n, m};
-  auto dbias_shape = std::vector<size_t>{n};
-  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
-  auto output_tensor =
-      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
-  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
-  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
-                            output_trans_tensor.data(), dbias_tensor.data(),
-                            workspace.data(), stream);
-}
-pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
-                                                  DType in_dtype, DType w_dtype, DType out_dtype,
-                                                  bool is_layer_norm, bool zero_centered_gamma,
-                                                  float eps) {
-  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto weight_shape = std::vector<size_t>{hidden_size};
-  auto intermediates_shape = std::vector<size_t>{batch_size};
-  // empty tensor wrappers are okay just to get workspace size
-  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
-  auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
-  auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
-  // dummy tensor wrappers that will carry workspace size info later
-  TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
-  auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
-  if (is_layer_norm) {
-    auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
-    auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
-    layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
-                       output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
-                       num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
-  } else {
-    NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
-    nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
-                     rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
-                     dummy_barrier_tensor.data());
-  }
-  auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
-  auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
-  return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
-                              std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
-}
-void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
-                          size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
-                          DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
-                          DType out_dtype, void *workspace, DType work_dtype, void *barrier,
-                          DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
-                          float *scale_inv, cudaStream_t stream) {
-  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto weight_shape = std::vector<size_t>{hidden_size};
-  auto intermediates_shape = std::vector<size_t>{batch_size};
-  auto workspace_shape = std::vector<size_t>{workspace_size};
-  auto barrier_shape = std::vector<size_t>{barrier_size};
-  auto is_layer_norm = (bias) ? true : false;
-  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
-  auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);
-  // assume output dtype = input dtype
-  // If we need mixed I/O precision in the future, we need an additional
-  // parameter for output type
-  auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
-  auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
-  auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
-  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
-  auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
-  if (is_layer_norm) {
-    auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
-    auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
-    layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
-                       output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
-                       num_sm, workspace_tensor.data(), barrier_tensor.data());
-  } else {
-    NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
-    nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
-                     rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
-                     barrier_tensor.data());
-  }
-}
-pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
-                                                   DType in_dtype, DType w_dtype,
-                                                   bool is_layer_norm, bool zero_centered_gamma,
-                                                   float eps) {
-  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto weight_shape = std::vector<size_t>{hidden_size};
-  auto intermediates_shape = std::vector<size_t>{batch_size};
-  auto intermediates_dtype = DType::kFloat32;
-  // empty tensor wrappers are okay just to get workspace size
-  auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
-  auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
-  auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
-  // dummy tensor wrappers that will carry workspace size info later
-  TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
-  TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
-  auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
-  // initialize dBeta information here -- layernorm will modify but RMSnorm will not
-  std::vector<size_t> dbeta_part_shape;
-  if (is_layer_norm) {
-    auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
-    auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
-    layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
-                       rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
-                       wgrad_tensor.data(), dbeta_tensor.data(),
-                       dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
-                       num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
-    dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
-  } else {
-    NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
-    nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
-                     gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
-                     dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
-                     dummy_barrier_tensor.data());
-    dbeta_part_shape = std::vector<size_t>{0, 0};
-  }
-  auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
-  auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
-  auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
-  return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
-                              std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
-                              std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
-                              std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
-}
-void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
-                           size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
-                           bool zero_centered_gamma, float eps, void *input, DType in_dtype,
-                           void *weight, DType w_dtype, void *ograd, void *workspace,
-                           DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
-                           void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
-                           DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
-                           cudaStream_t stream) {
-  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
-  auto weight_shape = std::vector<size_t>{hidden_size};
-  auto intermediates_shape = std::vector<size_t>{batch_size};
-  auto intermediates_dtype = DType::kFloat32;
-  auto is_layer_norm = (dbeta) ? true : false;
-  // assume input type = output type
-  auto *grad_output = ograd;
-  auto x_dtype = in_dtype;
-  auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);
-  auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);
-  auto *x = input;
-  auto x_tensor = TensorWrapper(x, input_shape, x_dtype);
-  auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
-  auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
-  auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
-  auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
-  auto workspace_shape = std::vector<size_t>{wkspace_size};
-  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
-  auto barrier_shape = std::vector<size_t>{barrier_size};
-  auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
-  auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
-  if (is_layer_norm) {
-    auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
-    auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
-    auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
-    layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
-                       rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
-                       wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
-                       dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
-                       barrier_tensor.data());
-  } else {
-    NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
-    nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
-                     gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
-                     dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
-                     barrier_tensor.data());
-  }
-}
-void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
-                         size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *weight = buffers[1];
-  auto *bias = buffers[2];
-  auto *amax = reinterpret_cast<float *>(buffers[3]);
-  auto *scale = reinterpret_cast<float *>(buffers[4]);
-  auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
-  auto *output = buffers[6];
-  auto *mu = buffers[7];
-  auto *rsigma = buffers[8];
-  auto *amax_out = buffers[9];
-  auto *workspace = buffers[10];
-  auto *barrier = buffers[11];
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
-  auto batch_size = desc.batch_size;
-  auto hidden_size = desc.hidden_size;
-  auto wkspace_size = desc.wkspace_size;
-  auto barrier_size = desc.barrier_size;
-  auto in_dtype = desc.x_dtype;
-  auto w_dtype = desc.w_dtype;
-  auto wkspace_dtype = desc.wkspace_dtype;
-  auto barrier_dtype = desc.barrier_dtype;
-  auto eps = desc.eps;
-  auto zero_centered_gamma = desc.zero_centered_gamma;
-  auto out_dtype = DType::kFloat8E4M3;
-  LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
-                       eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
-                       wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
-}
-void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *weight = buffers[1];
-  auto *bias = buffers[2];
-  auto *output = buffers[3];
-  auto *mu = buffers[4];
-  auto *rsigma = buffers[5];
-  auto *workspace = buffers[6];
-  auto *barrier = buffers[7];
-  float *amax = nullptr;
-  float *scale = nullptr;
-  float *scale_inv = nullptr;
-  const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
-  auto batch_size = desc.batch_size;
-  auto hidden_size = desc.hidden_size;
-  auto wkspace_size = desc.wkspace_size;
-  auto barrier_size = desc.barrier_size;
-  auto in_dtype = desc.x_dtype;
-  auto w_dtype = desc.w_dtype;
-  auto wkspace_dtype = desc.wkspace_dtype;
-  auto barrier_dtype = desc.barrier_dtype;
-  auto eps = desc.eps;
-  auto out_dtype = in_dtype;
-  auto zero_centered_gamma = desc.zero_centered_gamma;
-  LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
-                       eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
-                       wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
-}
-void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
-  auto batch_size = desc.batch_size;
-  auto hidden_size = desc.hidden_size;
-  auto wkspace_size = desc.wkspace_size;
-  auto barrier_size = desc.barrier_size;
-  auto dgamma_part_shape = desc.dgamma_part_shape;
-  auto dbeta_part_shape = desc.dbeta_part_shape;
-  auto in_dtype = desc.x_dtype;
-  auto w_dtype = desc.w_dtype;
-  auto wkspace_dtype = desc.wkspace_dtype;
-  auto barrier_dtype = desc.barrier_dtype;
-  auto dgamma_part_dtype = desc.dgamma_part_dtype;
-  auto dbeta_part_dtype = desc.dbeta_part_dtype;
-  auto eps = desc.eps;
-  auto zero_centered_gamma = desc.zero_centered_gamma;
-  auto *ograd = buffers[0];
-  auto *mu = buffers[1];
-  auto *rsigma = buffers[2];
-  auto *input = buffers[3];
-  auto *weight = buffers[4];
-  auto *xgrad = buffers[5];
-  auto *wgrad = buffers[6];
-  auto *dbeta = buffers[7];
-  auto *workspace = buffers[8];
-  auto *barrier = buffers[9];
-  auto *dgamma_part = buffers[10];
-  auto *dbeta_part = buffers[11];
-  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
-                        dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
-                        w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
-                        rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
-                        dbeta_part_dtype, stream);
-}
-void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *weight = buffers[1];
-  auto *amax = reinterpret_cast<float *>(buffers[2]);
-  auto *scale = reinterpret_cast<float *>(buffers[3]);
-  auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
-  auto *output = buffers[5];
-  auto *rsigma = buffers[6];
-  auto *amax_out = buffers[7];
-  auto *workspace = buffers[8];
-  auto *barrier = buffers[9];
-  NVTE_CHECK(amax == amax_out,
-             "Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
-  void *bias = nullptr;
-  void *mu = nullptr;
-  const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
-  auto batch_size = desc.batch_size;
-  auto hidden_size = desc.hidden_size;
-  auto wkspace_size = desc.wkspace_size;
-  auto barrier_size = desc.barrier_size;
-  auto in_dtype = desc.x_dtype;
-  auto w_dtype = desc.w_dtype;
-  auto wkspace_dtype = desc.wkspace_dtype;
-  auto barrier_dtype = desc.barrier_dtype;
-  auto eps = desc.eps;
-  auto zero_centered_gamma = desc.zero_centered_gamma;
-  auto out_dtype = DType::kFloat8E4M3;
-  LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
-                       eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
-                       wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
-}
-void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *weight = buffers[1];
-  auto *output = buffers[2];
-  auto *rsigma = buffers[3];
-  auto *workspace = buffers[4];
-  auto *barrier = buffers[5];
-  void *bias = nullptr;
-  void *mu = nullptr;
-  float *amax = nullptr;
-  float *scale = nullptr;
-  float *scale_inv = nullptr;
-  const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
-  auto batch_size = desc.batch_size;
-  auto hidden_size = desc.hidden_size;
-  auto wkspace_size = desc.wkspace_size;
-  auto barrier_size = desc.barrier_size;
-  auto in_dtype = desc.x_dtype;
-  auto w_dtype = desc.w_dtype;
-  auto wkspace_dtype = desc.wkspace_dtype;
-  auto barrier_dtype = desc.barrier_dtype;
-  auto eps = desc.eps;
-  auto zero_centered_gamma = desc.zero_centered_gamma;
-  auto out_dtype = in_dtype;
-  LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
-                       eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
-                       wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
-}
-void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *ograd = buffers[0];
-  auto *rsigma = buffers[1];
-  auto *input = buffers[2];
-  auto *weight = buffers[3];
-  auto *xgrad = buffers[4];
-  auto *wgrad = buffers[5];
-  auto *workspace = buffers[6];
-  auto *barrier = buffers[7];
-  auto *dgamma_part = buffers[8];
-  void *mu = nullptr;
-  void *dbeta = nullptr;
-  void *dbeta_part = nullptr;
-  const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
-  auto batch_size = desc.batch_size;
-  auto hidden_size = desc.hidden_size;
-  auto wkspace_size = desc.wkspace_size;
-  auto barrier_size = desc.barrier_size;
-  auto dgamma_part_shape = desc.dgamma_part_shape;
-  Shape dbeta_part_shape;
-  dbeta_part_shape.from_vector({0, 0});
-  auto in_dtype = desc.x_dtype;
-  auto w_dtype = desc.w_dtype;
-  auto wkspace_dtype = desc.wkspace_dtype;
-  auto barrier_dtype = desc.barrier_dtype;
-  auto dgamma_part_dtype = desc.dgamma_part_dtype;
-  auto dbeta_part_dtype = DType::kByte;
-  auto eps = desc.eps;
-  auto zero_centered_gamma = desc.zero_centered_gamma;
-  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
-                        dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
-                        w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
-                        rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
-                        dbeta_part_dtype, stream);
-}
-void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  auto *input = buffers[0];
-  auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor,
stream);
}
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *mask = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape =
std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, io_shape, dtype);
// Mask is cast to uint8_t
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto output_tensor = TensorWrapper(output, io_shape, dtype);
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
output_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax.
ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
desc.scale_factor, stream);
}
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_backward(
grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
desc.scale_factor, stream);
}
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
void Shape::from_vector(const std::vector<size_t> &shape) {
num_dim = shape.size();
assert(num_dim <= kMaxNumDim);
std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
}
std::vector<size_t> Shape::to_vector() const {
assert(num_dim <= kMaxNumDim);
std::vector<size_t> shape(num_dim);
std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
return shape;
}
} // namespace jax
} // namespace transformer_engine
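// [Editor's note] A minimal sketch, not part of this commit, of how Shape round-trips
// through the PackOpaque/UnpackOpaque helpers declared in extensions.h: the struct is
// byte-copied into a pybind11::bytes at lowering time and reinterpreted inside the
// custom call. The function name ShapeRoundTripSketch is hypothetical, and constructing
// a pybind11::bytes requires a live Python interpreter.
#include <cassert>
#include <string>
#include <vector>
#include "jax/csrc/extensions.h"
namespace tej = transformer_engine::jax;
static void ShapeRoundTripSketch() {
  tej::Shape s;
  s.from_vector({128, 1024});  // at most kMaxNumDim (8) dimensions
  std::string opaque = tej::PackOpaque(s);  // pybind11::bytes converts to std::string
  // XLA later hands (data, length) back to the registered custom call:
  const tej::Shape *restored = tej::UnpackOpaque<tej::Shape>(opaque.data(), opaque.size());
  assert(restored->to_vector() == std::vector<size_t>({128, 1024}));
}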
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
namespace transformer_engine {
namespace jax {
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
// empty tensor wrappers are okay just to get workspace size
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
}
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
}
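// [Editor's note] The query pattern above relies on a TE convention: when the workspace
// and barrier TensorWrappers are default-constructed (null data pointers), the nvte_*
// call skips the kernel launch and instead records the workspace shape/dtype it needs,
// which is returned to Python so real buffers can be allocated before the executing call.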
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
DType out_dtype, void *workspace, DType work_dtype, void *barrier,
DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
float *scale_inv, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto workspace_shape = std::vector<size_t>{workspace_size};
auto barrier_shape = std::vector<size_t>{barrier_size};
auto is_layer_norm = (bias != nullptr);
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);
// The output dtype is supplied by the caller: it matches the input dtype on the
// regular forward path and is FP8 (kFloat8E4M3) on the *ForwardFP8 paths.
auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
}
}
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32;
// empty tensor wrappers are okay just to get workspace size
auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
// Initialize dbeta_part info here: the LayerNorm path fills it in, RMSNorm leaves it empty.
std::vector<size_t> dbeta_part_shape;
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0};
}
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
}
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *ograd, void *workspace,
DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32;
auto is_layer_norm = (dbeta != nullptr);
// assume input type = output type
auto *grad_output = ograd;
auto x_dtype = in_dtype;
auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);
auto *x = input;
auto x_tensor = TensorWrapper(x, input_shape, x_dtype);
auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
auto workspace_shape = std::vector<size_t>{wkspace_size};
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto barrier_shape = std::vector<size_t>{barrier_size};
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
auto dgamma_part_tensor =
TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_tensor =
TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
}
}
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *bias = buffers[2];
auto *amax = reinterpret_cast<float *>(buffers[3]);
auto *scale = reinterpret_cast<float *>(buffers[4]);
auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
auto *output = buffers[6];
auto *mu = buffers[7];
auto *rsigma = buffers[8];
auto *amax_out = buffers[9];
auto *workspace = buffers[10];
auto *barrier = buffers[11];
assert(amax_out == amax);
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *bias = buffers[2];
auto *output = buffers[3];
auto *mu = buffers[4];
auto *rsigma = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_shape = desc.dgamma_part_shape;
auto dbeta_part_shape = desc.dbeta_part_shape;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto *ograd = buffers[0];
auto *mu = buffers[1];
auto *rsigma = buffers[2];
auto *input = buffers[3];
auto *weight = buffers[4];
auto *xgrad = buffers[5];
auto *wgrad = buffers[6];
auto *dbeta = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
}
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *amax = reinterpret_cast<float *>(buffers[2]);
auto *scale = reinterpret_cast<float *>(buffers[3]);
auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *rsigma = buffers[6];
auto *amax_out = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
assert(amax_out == amax);
void *bias = nullptr;
void *mu = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *output = buffers[2];
auto *rsigma = buffers[3];
auto *workspace = buffers[4];
auto *barrier = buffers[5];
void *bias = nullptr;
void *mu = nullptr;
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *ograd = buffers[0];
auto *rsigma = buffers[1];
auto *input = buffers[2];
auto *weight = buffers[3];
auto *xgrad = buffers[4];
auto *wgrad = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
auto *dgamma_part = buffers[8];
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_shape = desc.dgamma_part_shape;
Shape dbeta_part_shape;
dbeta_part_shape.from_vector({0, 0});
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = DType::kByte;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype, size_t act_enum) {
CustomCallCommonDescriptor desc;
desc.shape.from_vector(shape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype,
size_t act_enum) {
CustomCallCommonWkDescriptor desc;
desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.wk_dtype = wk_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallNormDescriptor(
size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
CustomCallNormDescriptor desc;
desc.batch_size = batch_size;
desc.hidden_size = hidden_size;
desc.wkspace_size = wkspace_size;
desc.barrier_size = barrier_size;
desc.dgamma_part_shape.from_vector(dgamma_part_shape);
desc.dbeta_part_shape.from_vector(dbeta_part_shape);
desc.x_dtype = x_dtype;
desc.w_dtype = w_dtype;
desc.wkspace_dtype = wkspace_dtype;
desc.barrier_dtype = barrier_dtype;
desc.dgamma_part_dtype = dgamma_part_dtype;
desc.dbeta_part_dtype = dbeta_part_dtype;
desc.zero_centered_gamma = zero_centered_gamma;
desc.eps = eps;
desc.sm_margin = sm_margin;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor) {
return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
dtype, scale_factor});
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
}
} // namespace jax
} // namespace transformer_engine
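// [Editor's note] PackOpaque byte-copies sizeof(T) into a Python bytes object, so every
// descriptor packed above must stay trivially copyable. A compile-time guard one could
// add under that assumption (hypothetical, not part of this commit):
#include <type_traits>
static_assert(
    std::is_trivially_copyable<transformer_engine::jax::CustomCallNormDescriptor>::value,
    "descriptors cross the XLA custom-call boundary as raw bytes");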
...@@ -3,16 +3,8 @@
*
* See LICENSE for license information.
************************************************************************/
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-#include <cublasLt.h>
-#include "common/include/transformer_engine/fused_attn.h"
-#include "common/include/transformer_engine/activation.h"
-#include "common/include/transformer_engine/transformer_engine.h"
-#include "modules.h"
-#include "utils.h"
+#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/cast.h"
namespace transformer_engine {
namespace jax {
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}
} // namespace jax
} // namespace transformer_engine
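// [Editor's note] A minimal sketch, not part of this commit, of the FP8 round trip the
// two handlers above implement. Names are hypothetical; every pointer is device memory
// allocated elsewhere, and amax/scale/scale_inv each point to a single float on device.
#include <vector>
#include "jax/csrc/extensions.h"
#include "transformer_engine/cast.h"
namespace te = transformer_engine;
void Fp8RoundTripSketch(void *d_in, void *d_fp8, void *d_out, float *d_amax, float *d_scale,
                        float *d_scale_inv, cudaStream_t stream) {
  std::vector<size_t> shape{128, 1024};
  auto hi_in = te::TensorWrapper(d_in, shape, te::DType::kBFloat16);
  // The FP8 tensor carries its scaling metadata, exactly as in Quantize()/Dequantize().
  auto fp8 = te::TensorWrapper(d_fp8, shape, te::DType::kFloat8E4M3, d_amax, d_scale,
                               d_scale_inv);
  auto hi_out = te::TensorWrapper(d_out, shape, te::DType::kBFloat16);
  nvte_fp8_quantize(hi_in.data(), fp8.data(), stream);     // cast down, update amax
  nvte_fp8_dequantize(fp8.data(), hi_out.data(), stream);  // cast back using scale_inv
}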
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/softmax.h"
namespace transformer_engine {
namespace jax {
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor,
stream);
}
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *mask = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape =
std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, io_shape, dtype);
// Mask is cast to uint8_t
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto output_tensor = TensorWrapper(output, io_shape, dtype);
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
output_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax.
ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
desc.scale_factor, stream);
}
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_backward(
grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
desc.scale_factor, stream);
}
} // namespace jax
} // namespace transformer_engine
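// [Editor's note] Shape convention in this file: the regular and masked kernels view the
// logits as {batch, heads, q_seqlen, k_seqlen}, while the upper-triangular (causal)
// variant folds batch and heads into a single leading dimension, matching the 3D layout
// its kernel expects.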
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
namespace jax {
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
void *output) {
auto input_shape = std::vector<size_t>{rows, cols};
auto output_shape = std::vector<size_t>{cols, rows};
auto input_tensor = TensorWrapper(input, input_shape, dtype);
auto transposed_tensor = TensorWrapper(output, output_shape, dtype);
nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
void *input = buffers[0];
void *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto rows = desc.shape.dims[0];
auto cols = desc.shape.dims[1];
assert(desc.in_dtype == desc.out_dtype);
auto dtype = desc.out_dtype;
TransposeImpl(input, rows, cols, dtype, stream, output);
}
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto input_cast_tensor =
TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape,
desc.out_dtype, amax_out, scale, scale_inv);
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(),
input_cast_trans_tensor.data(), stream);
}
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *output_trans = buffers[5];
auto *dbias = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
workspace.data(), stream);
}
} // namespace jax
} // namespace transformer_engine
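// [Editor's note] The fused kernels above exist because FP8 GEMMs consume both layouts:
// nvte_cast_transpose writes the FP8 cast and its transpose in one pass over the input,
// and nvte_cast_transpose_dbias additionally reduces the bias gradient, which is why it
// alone takes a workspace (queried first with dummy tensors, then passed in for real).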