Unverified commit 16f3f897, authored by Phuong Nguyen and committed by GitHub
Browse files

[JAX] Splitting `csrc/modules.cpp` by category (#883)



* categorized `csrc/modules.cpp`
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* adapted the build tool
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c6ce2b8c
......@@ -6,8 +6,9 @@
from pathlib import Path
import setuptools
from glob import glob
from .utils import cuda_path
from .utils import cuda_path, all_files_in_dir
from typing import List
......@@ -19,11 +20,10 @@ def setup_jax_extension(
"""Setup PyBind11 extension for JAX support"""
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "extensions.cpp",
csrc_source_files / "modules.cpp",
csrc_source_files / "utils.cu",
]
] + all_files_in_dir(extensions_dir)
# Header files
cuda_home, _ = cuda_path()
......
......@@ -11,6 +11,9 @@
#include <cstddef>
#include <cstdint>
#include <vector>
#include <stdexcept>
#include <string>
#include <iostream>
#include <cuda_runtime_api.h>
#include <pybind11/pybind11.h>
......@@ -19,41 +22,54 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/activation.h>
#include "common/common.h"
#include "common/util/logging.h"
#include "utils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
namespace transformer_engine {
namespace jax {
constexpr int kMaxNumDim = 8;
size_t get_activation_len(NVTE_Activation_Type activation_enum);
// TODO: Rename Shape to ???
struct Shape {
int num_dim;
size_t dims[kMaxNumDim];
void from_vector(const std::vector<size_t> &shape) {
num_dim = shape.size();
assert(num_dim <= kMaxNumDim);
std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
}
void from_vector(const std::vector<size_t> &shape);
std::vector<size_t> to_vector() const {
assert(num_dim <= kMaxNumDim);
std::vector<size_t> shape(num_dim);
std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
return shape;
}
std::vector<size_t> to_vector() const;
};
enum class NVTE_Activation_Enum {
GELU,
GEGLU,
SILU,
SWIGLU,
};
// Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1.
// Returns true iff `type` is one of the FP8 formats (E4M3 or E5M2), i.e. the
// tensor carries scale/scale_inv/amax quantization metadata.
inline bool use_fp8(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
}
// 2.
// Serializes a trivially-copyable descriptor struct into a pybind11::bytes
// blob so it can be passed through XLA's custom-call `opaque` argument.
// NOTE(review): assumes T is trivially copyable (raw memcpy of the object
// representation) — descriptors in this file satisfy that.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
  auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
  return pybind11::bytes(str);
}
// 3.
// Inverse of PackOpaque: reinterprets the custom-call `opaque` byte blob as a
// descriptor of type T. Throws if the blob size does not match sizeof(T),
// which guards against packing/unpacking with mismatched descriptor types.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return reinterpret_cast<const T *>(opaque);
}
std::vector<size_t> MakeShapeVector(NVTEShape shape);
size_t get_activation_len(NVTE_Activation_Enum act_enum);
// Packing
struct CustomCallCommonDescriptor {
Shape shape;
......@@ -144,17 +160,22 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim);
// Transpose
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
// Activation
size_t get_activation_len(NVTE_Activation_Type activation_enum);
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -167,15 +188,11 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
// Normalization
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
......@@ -199,10 +216,14 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
......@@ -221,6 +242,15 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
// Attention
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
namespace jax {
// Number of logical input chunks an activation consumes: gated variants
// (GEGLU, SWIGLU, REGLU, QGEGLU, SREGLU) read two chunks (value and gate),
// all plain activations read one. Errors out on unknown enum values.
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
  switch (activation_enum) {
    case NVTE_Activation_Type::GELU:
    case NVTE_Activation_Type::SILU:
    case NVTE_Activation_Type::RELU:
    case NVTE_Activation_Type::QGELU:
    case NVTE_Activation_Type::SRELU:
      return 1;
    case NVTE_Activation_Type::GEGLU:
    case NVTE_Activation_Type::SWIGLU:
    case NVTE_Activation_Type::REGLU:
    case NVTE_Activation_Type::QGEGLU:
    case NVTE_Activation_Type::SREGLU:
      return 2;
    default:
      // NVTE_ERROR aborts; the old `break; return -1;` after it was
      // unreachable dead code (and -1 is not a meaningful size_t anyway).
      NVTE_ERROR("Unsupported ActivationEnum");
  }
}
// Shared forward-activation launcher for ActLu and ActLuFP8.
// `input` is (m, n * act_len) — act_len is 2 for gated activations, 1
// otherwise — and `output` is (m, n). For FP8 output, `scale`,
// `scale_inverse` and `amax` carry quantization metadata; all three may be
// nullptr for non-FP8 output.
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
               NVTE_Activation_Type act_enum) {
  auto act_len = get_activation_len(act_enum);
  auto input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n};
  // in_dtype/out_dtype are already DType — the previous static_cast<DType>
  // round-trips were redundant.
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, out_dtype, amax, scale, scale_inverse);
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
// XLA custom call: non-FP8 forward activation.
// Buffers: [0] input (m, n * act_len), [1] output (m, n).
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *output = buffers[1];
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  // Fixed: stray double semicolon after the cast.
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  // No FP8 metadata for the plain (non-quantizing) path.
  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
            output, act_enum);
}
// XLA custom call: forward activation with FP8 output quantization.
// Buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output,
// [5] amax_out (aliased with [1] by the caller).
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  float *amax_out = reinterpret_cast<float *>(buffers[5]);
  // The caller is expected to alias the amax in/out buffers.
  assert(amax == amax_out);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: drop the quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  // Fixed: stray double semicolon after the cast.
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
            output, act_enum);
}
// XLA custom call: activation backward (dgrad).
// Buffers: [0] incoming gradient (m, n), [1] forward activation input
// (m, n * act_len), [2] output gradient (m, n * act_len).
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  auto *output = buffers[2];
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  // Fixed: stray double semicolon after the cast.
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto act_len = get_activation_len(act_enum);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
// Query the workspace requirement of the fused dact+dbias+cast+transpose
// kernel. All tensors are shape-only (nullptr data): calling the kernel with
// a null stream and an empty workspace makes it report the needed size
// instead of executing.
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
  const std::vector<size_t> io_shape{batch_size, hidden_size};
  const std::vector<size_t> trans_shape{hidden_size, batch_size};
  const std::vector<size_t> bias_shape{hidden_size};
  auto grad_tensor = TensorWrapper(nullptr, io_shape, in_dtype);
  auto dact_in_tensor = TensorWrapper(nullptr, io_shape, in_dtype);
  auto out_tensor = TensorWrapper(nullptr, io_shape, out_dtype);
  auto out_trans_tensor = TensorWrapper(nullptr, trans_shape, out_dtype);
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, in_dtype);
  TensorWrapper query_workspace;
  // For now, all dbias_dact(-s) share the same workspace size, so querying
  // the dgelu variant is representative.
  nvte_cast_transpose_dbias_dgelu(grad_tensor.data(), dact_in_tensor.data(), out_tensor.data(),
                                  out_trans_tensor.data(), dbias_tensor.data(),
                                  query_workspace.data(), nullptr);
  auto wk_shape = MakeShapeVector(query_workspace.shape());
  return pybind11::make_tuple(std::make_pair(wk_shape, query_workspace.dtype()));
}
// XLA custom call: fused activation backward + bias gradient + FP8 cast +
// transpose (non-gated activations only).
// Buffers: [0] grad, [1] act input, [2] amax, [3] scale, [4] scale_inv,
// [5] output, [6] transposed output, [7] dbias, [8] amax_out (aliased with
// [2]), [9] workspace.
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  auto *dbias = buffers[7];
  float *amax_out = reinterpret_cast<float *>(buffers[8]);
  void *workspace_ptr = buffers[9];
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: drop the quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  // Fixed: stray double semicolon after the cast.
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
  // Only non-gated activations have fused dbias kernels.
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
// XLA custom call: fused gated-activation backward + FP8 cast + transpose.
// Buffers: [0] grad (m, n), [1] act input (m, 2n), [2] amax, [3] scale,
// [4] scale_inv, [5] output (m, 2n), [6] transposed output (2n, m),
// [7] amax_out (aliased with [2]).
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: drop the quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  // Fixed: stray double semicolon after the cast.
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto input_shape = desc.shape.to_vector();
  // Gated activations read/write an extra gate chunk, hence the factor 2.
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  switch (act_enum) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
// Copy the dimensions of an NVTEShape into a std::vector<size_t>.
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
  std::vector<size_t> dims;
  dims.assign(shape.data, shape.data + shape.ndim);
  return dims;
}
// Copy a dynamic shape vector into the fixed-size dims array.
// NOTE(review): the bound is only checked via assert, which compiles out in
// release builds — a shape with more than kMaxNumDim dims would then overflow
// `dims`. Consider a hard check if untrusted shapes can reach here.
void Shape::from_vector(const std::vector<size_t> &shape) {
  num_dim = shape.size();
  assert(num_dim <= kMaxNumDim);
  std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
}
// Materialize the fixed-size dims array back into a std::vector<size_t>.
std::vector<size_t> Shape::to_vector() const {
  assert(num_dim <= kMaxNumDim);
  // Range-construct directly from the dims array (equivalent to the former
  // resize + memcpy).
  return std::vector<size_t>(dims, dims + num_dim);
}
} // namespace jax
} // namespace transformer_engine
This diff is collapsed.
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
// Build and serialize the common descriptor shared by most custom calls:
// operand shape, input/output dtypes and the activation selector.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum) {
  CustomCallCommonDescriptor desc;
  desc.in_dtype = in_dtype;
  desc.out_dtype = out_dtype;
  desc.act_enum = act_enum;
  desc.shape.from_vector(shape);
  return PackOpaque(desc);
}
// Like PackCustomCallCommonDescriptor, but additionally records the workspace
// buffer's shape and dtype for kernels that need scratch memory.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype,
                                                 size_t act_enum) {
  CustomCallCommonWkDescriptor desc;
  desc.in_dtype = in_dtype;
  desc.out_dtype = out_dtype;
  desc.wk_dtype = wk_dtype;
  desc.act_enum = act_enum;
  desc.shape.from_vector(shape);
  desc.wkshape.from_vector(wkshape);
  return PackOpaque(desc);
}
// Build and serialize the descriptor for LayerNorm/RMSNorm custom calls.
// Field-by-field copy of every kernel parameter into the POD descriptor.
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
  CustomCallNormDescriptor desc;
  // Problem and scratch-buffer sizes.
  desc.batch_size = batch_size;
  desc.hidden_size = hidden_size;
  desc.wkspace_size = wkspace_size;
  desc.barrier_size = barrier_size;
  // Backward partial-reduction buffer shapes.
  desc.dgamma_part_shape.from_vector(dgamma_part_shape);
  desc.dbeta_part_shape.from_vector(dbeta_part_shape);
  // Dtypes of input, weight and the scratch buffers.
  desc.x_dtype = x_dtype;
  desc.w_dtype = w_dtype;
  desc.wkspace_dtype = wkspace_dtype;
  desc.barrier_dtype = barrier_dtype;
  desc.dgamma_part_dtype = dgamma_part_dtype;
  desc.dbeta_part_dtype = dbeta_part_dtype;
  // Normalization options.
  desc.zero_centered_gamma = zero_centered_gamma;
  desc.eps = eps;
  desc.sm_margin = sm_margin;
  return PackOpaque(desc);
}
// Build and serialize the descriptor for the fused softmax custom calls.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor) {
  SoftmaxDescriptor desc{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
                         dtype, scale_factor};
  return PackOpaque(desc);
}
// Build and serialize the descriptor for the fused attention custom calls.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training) {
  CustomCallFusedAttnDescriptor desc{
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
      bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
      mask_type, qkv_layout, dtype, wkspace_dtype, is_training};
  return PackOpaque(desc);
}
} // namespace jax
} // namespace transformer_engine
......@@ -3,16 +3,8 @@
*
* See LICENSE for license information.
************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cublasLt.h>
#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/activation.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "modules.h"
#include "utils.h"
#include "jax/csrc/extensions.h"
namespace transformer_engine {
namespace jax {
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/cast.h"
namespace transformer_engine {
namespace jax {
// XLA custom call: cast a tensor to FP8.
// Buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output,
// [5] amax_out (aliased with [1]).
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *in_buf = buffers[0];
  auto *amax = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *out_buf = buffers[4];
  auto *amax_out = reinterpret_cast<float *>(buffers[5]);
  // The caller is expected to alias the amax in/out buffers.
  assert(amax == amax_out);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto dims = desc.shape.to_vector();
  auto in_tensor = TensorWrapper(in_buf, dims, desc.in_dtype);
  auto out_tensor = TensorWrapper(out_buf, dims, desc.out_dtype, amax_out, scale, scale_inv);
  nvte_fp8_quantize(in_tensor.data(), out_tensor.data(), stream);
}
// XLA custom call: cast an FP8 tensor back to a higher-precision dtype.
// Buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output.
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *in_buf = buffers[0];
  auto *amax = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *out_buf = buffers[4];
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto dims = desc.shape.to_vector();
  // The quantization metadata belongs to the FP8 *input* here.
  auto in_tensor = TensorWrapper(in_buf, dims, desc.in_dtype, amax, scale, scale_inv);
  auto out_tensor = TensorWrapper(out_buf, dims, desc.out_dtype);
  nvte_fp8_dequantize(in_tensor.data(), out_tensor.data(), stream);
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/softmax.h"
namespace transformer_engine {
namespace jax {
// XLA custom call: scaled softmax forward.
// Buffers: [0] input, [1] output — both (batch, heads, q_seqlen, k_seqlen).
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> shape{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
  auto in_tensor = TensorWrapper(buffers[0], shape, desc.dtype);
  auto out_tensor = TensorWrapper(buffers[1], shape, desc.dtype);
  nvte_scaled_softmax_forward(in_tensor.data(), out_tensor.data(), desc.scale_factor, stream);
}
// XLA custom call: scaled softmax backward.
// Buffers: [0] incoming gradient, [1] saved softmax output, [2] input
// gradient — all (batch, heads, q_seqlen, k_seqlen).
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> shape{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
  auto grad_tensor = TensorWrapper(buffers[0], shape, desc.dtype);
  auto fwd_out_tensor = TensorWrapper(buffers[1], shape, desc.dtype);
  auto dgrad_tensor = TensorWrapper(buffers[2], shape, desc.dtype);
  nvte_scaled_softmax_backward(grad_tensor.data(), fwd_out_tensor.data(), dgrad_tensor.data(),
                               desc.scale_factor, stream);
}
// XLA custom call: scaled softmax forward with an additive byte mask.
// Buffers: [0] input, [1] mask, [2] output.
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> io_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                     desc.k_seqlen};
  // The mask has a single head dimension (broadcast over heads) and is
  // treated as raw bytes (uint8) by the kernel.
  const std::vector<size_t> mask_shape{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
  auto in_tensor = TensorWrapper(buffers[0], io_shape, desc.dtype);
  auto mask_tensor = TensorWrapper(buffers[1], mask_shape, DType::kByte);
  auto out_tensor = TensorWrapper(buffers[2], io_shape, desc.dtype);
  nvte_scaled_masked_softmax_forward(in_tensor.data(), mask_tensor.data(), out_tensor.data(),
                                     desc.scale_factor, stream);
}
// XLA custom call: scaled masked softmax backward.
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax, so this
// simply forwards to it (the mask does not enter the gradient computation).
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 size_t opaque_len) {
  ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}
// XLA custom call: scaled causal (upper-triangular-masked) softmax forward.
// Batch and head dims are folded together; buffers: [0] input, [1] output.
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const size_t attn_batch = desc.batch_size * desc.head_dim;
  const std::vector<size_t> shape{attn_batch, desc.q_seqlen, desc.k_seqlen};
  auto in_tensor = TensorWrapper(buffers[0], shape, desc.dtype);
  auto out_tensor = TensorWrapper(buffers[1], shape, desc.dtype);
  nvte_scaled_upper_triang_masked_softmax_forward(in_tensor.data(), out_tensor.data(),
                                                  desc.scale_factor, stream);
}
// XLA custom call: scaled causal (upper-triangular-masked) softmax backward.
// Buffers: [0] incoming gradient, [1] saved softmax output, [2] input grad.
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const size_t attn_batch = desc.batch_size * desc.head_dim;
  const std::vector<size_t> shape{attn_batch, desc.q_seqlen, desc.k_seqlen};
  auto grad_tensor = TensorWrapper(buffers[0], shape, desc.dtype);
  auto fwd_out_tensor = TensorWrapper(buffers[1], shape, desc.dtype);
  auto dgrad_tensor = TensorWrapper(buffers[2], shape, desc.dtype);
  nvte_scaled_upper_triang_masked_softmax_backward(grad_tensor.data(), fwd_out_tensor.data(),
                                                   dgrad_tensor.data(), desc.scale_factor,
                                                   stream);
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
namespace jax {
// Launch a plain (rows, cols) -> (cols, rows) transpose of `input` into
// `output` on `stream`.
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
  const std::vector<size_t> in_shape{rows, cols};
  const std::vector<size_t> out_shape{cols, rows};
  auto in_tensor = TensorWrapper(input, in_shape, dtype);
  auto out_tensor = TensorWrapper(output, out_shape, dtype);
  nvte_transpose(in_tensor.data(), out_tensor.data(), stream);
}
// XLA custom call: plain 2-D transpose. Buffers: [0] input, [1] output.
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  // A pure transpose never changes the element type.
  assert(desc.in_dtype == desc.out_dtype);
  TransposeImpl(buffers[0], desc.shape.dims[0], desc.shape.dims[1], desc.out_dtype, stream,
                buffers[1]);
}
// XLA custom call: fused cast (possibly to FP8) + transpose.
// Buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] cast output,
// [5] cast+transposed output, [6] amax_out (aliased with [1]).
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *in_buf = buffers[0];
  auto *amax = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *cast_buf = buffers[4];
  auto *cast_trans_buf = buffers[5];
  auto *amax_out = reinterpret_cast<float *>(buffers[6]);
  // The caller is expected to alias the amax in/out buffers.
  assert(amax == amax_out);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: drop the quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  const std::vector<size_t> in_shape{rows, cols};
  const std::vector<size_t> trans_shape{cols, rows};
  auto in_tensor = TensorWrapper(in_buf, in_shape, desc.in_dtype);
  auto cast_tensor =
      TensorWrapper(cast_buf, in_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto cast_trans_tensor =
      TensorWrapper(cast_trans_buf, trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  nvte_cast_transpose(in_tensor.data(), cast_tensor.data(), cast_trans_tensor.data(), stream);
}
// Query the workspace requirement of the fused cast+transpose+dbias kernel.
// All tensors are shape-only (nullptr data); calling with a null stream and
// an empty workspace makes the kernel report its requirement instead of
// executing.
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
  const std::vector<size_t> io_shape{batch_size, hidden_size};
  const std::vector<size_t> trans_shape{hidden_size, batch_size};
  const std::vector<size_t> bias_shape{hidden_size};
  auto in_tensor = TensorWrapper(nullptr, io_shape, in_dtype);
  auto out_tensor = TensorWrapper(nullptr, io_shape, out_dtype);
  auto out_trans_tensor = TensorWrapper(nullptr, trans_shape, out_dtype);
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, in_dtype);
  TensorWrapper query_workspace;
  nvte_cast_transpose_dbias(in_tensor.data(), out_tensor.data(), out_trans_tensor.data(),
                            dbias_tensor.data(), query_workspace.data(), nullptr);
  auto wk_shape = MakeShapeVector(query_workspace.shape());
  return pybind11::make_tuple(std::make_pair(wk_shape, query_workspace.dtype()));
}
// XLA custom call: fused cast (possibly FP8) + transpose + bias-gradient
// reduction.
// Buffers: [0] input grad (m, n), [1] amax, [2] scale, [3] scale_inv,
// [4] cast output (m, n), [5] transposed output (n, m), [6] dbias (n),
// [7] amax_out (aliased with [1]), [8] workspace.
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                        size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  auto *output_trans = buffers[5];
  auto *dbias = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  void *workspace_ptr = buffers[8];
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
  // The caller is expected to alias the amax in/out buffers.
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: drop the quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
  // Workspace shape/dtype were chosen via GetDBiasCastTransposeWorkspaceSizes.
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
                            output_trans_tensor.data(), dbias_tensor.data(),
                            workspace.data(), stream);
}
} // namespace jax
} // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment