Unverified commit 4077ccc1 authored by Alp Dener, committed by GitHub

[JAX] Custom Op Workspace Tensors from XLA Buffers (#532)



* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors.
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unused GEMM C++ API in TE-JAX
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed import order for linting
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed custom op errors due to incorrect static arg nums in JAX jit
Signed-off-by: Alp Dener <adener@nvidia.com>

* shifted cudnnSetStream further down the kernel implementation to avoid errors when executing a dummy kernel call with a nullptr stream
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for blank lines
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
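
For readers following along: the heart of the change is a two-phase convention for every custom op's C++ implementation. Called with a null workspace pointer, it only reports the byte count it needs; called again with a real XLA-allocated buffer, it does the work. Below is a minimal sketch of that convention with hypothetical names (`fused_op_impl` does not exist in this PR):

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Hypothetical illustration of the sizing convention this PR adopts: a null
// `workspace` turns the call into a pure size query, so the op itself never
// needs its own cudaMalloc.
void fused_op_impl(const float *input, float *output, size_t n,
                   void *workspace, size_t *workspace_size,
                   cudaStream_t stream) {
    const size_t required = n * sizeof(float);  // scratch derived from the problem shape
    if (workspace == nullptr) {
        *workspace_size = required;  // sizing call: report and return, touch no GPU state
        return;
    }
    // Execution call: `workspace` is an XLA-provided buffer of at least
    // `required` bytes, and `stream` is guaranteed to be valid here.
    // ... launch kernels on `stream` using `workspace` ...
}
```

On the Python side, the queried size becomes the shape of an extra custom-call output, so XLA owns the allocation and can pool and reuse it like any other buffer.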
parent bd7fd0a6
@@ -25,3 +25,5 @@ tests/cpp/build/
 docs/_build
 .ipynb_checkpoints
 docs/doxygen
+*.log
+CMakeFiles/CMakeSystem.cmake
\ No newline at end of file
@@ -20,7 +20,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
 from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
 from transformer_engine.jax.fp8 import is_fp8_available
 from transformer_engine.jax.layernorm import layernorm
-from transformer_engine.jax.mlp import layernrom_geglu_fp8_mlp
+from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp

 GEMM_CASES = [
     (256, 256, 512),
@@ -196,7 +196,7 @@ class TestFP8Dot:
             # out = (x * y) * z
             fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.mean(layernrom_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
+            return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))

 def _convert_to_activation_function(fn_or_string):
     """Convert a string to an activation function."""
...
@@ -59,8 +59,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                 cudnn_frontend::DataType_t tensorType,
                 void *workspace, size_t *workspace_size,
                 cudaStream_t stream, cudnnHandle_t handle) {
-    NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
-
     bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
     bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
     bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
@@ -248,6 +246,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         return;
     }

+    // cuDNN stream check needs to be moved here to support dummy kernel calls with
+    // null streams for sizing the cuDNN workspace.
+    NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
+
     // Build variant pack
     std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
         {Q, devPtrQ},
@@ -300,8 +302,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                 void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
                 cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
                 cudaStream_t stream, cudnnHandle_t handle) {
-    NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
-
     bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
     bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
     bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
@@ -519,6 +519,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         return;
     }

+    // cuDNN stream check needs to be moved here to support dummy kernel calls with
+    // null streams for sizing the cuDNN workspace.
+    NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
+
     // build variant pack
     std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
         {q, devPtrQ},
...
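The relocation of `cudnnSetStream` above is what makes the null-stream sizing call safe: the workspace query returns before any stream is ever bound to the handle. A simplified sketch of the resulting control flow (not the literal TE function; assumes the cuDNN headers):

```cpp
#include <cudnn.h>

// Simplified sketch: the sizing pass invokes this impl with
// workspace == nullptr and stream == nullptr, so binding the stream to the
// cuDNN handle may only happen after the early sizing return.
void attn_impl_sketch(void *workspace, size_t *workspace_size,
                      cudaStream_t stream, cudnnHandle_t handle) {
    size_t plan_workspace_size = 0;
    // ... build the cuDNN graph here; this determines plan_workspace_size ...

    if (workspace == nullptr) {
        *workspace_size = plan_workspace_size;
        return;  // dummy sizing call exits before any stream work
    }

    // Only reached on the execution path, where `stream` is a real stream.
    cudnnSetStream(handle, stream);  // TE wraps this in NVTE_CHECK_CUDNN
    // ... build the variant pack and execute the graph on `stream` ...
}
```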
@@ -642,8 +642,6 @@ void fused_attn_max_512_fwd_impl(
     void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size,
     cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) {
     try {
-        NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
-
         FADescriptor descriptor{b, h,
                                 s_q, s_kv,
                                 d, scaling_factor,
@@ -754,6 +752,10 @@ void fused_attn_max_512_fwd_impl(
             return;
         }

+        // cuDNN stream check needs to be moved here to support dummy kernel calls with
+        // null streams for sizing the cuDNN workspace.
+        NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
+
         // Prepare actual seqlen
         constexpr size_t nthreads_per_block = 128;
         const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
@@ -845,9 +847,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
                                  size_t *workspace_size, cudnnDataType_t tensorType,
                                  cudaStream_t stream, cudnnHandle_t handle) {
     try {
-        // Create cudnn handle
-        NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
-
         FADescriptor descriptor{
             b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability,
             layout, bias_type, mask_type, tensorType, false};
@@ -1194,6 +1193,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
             return;
         }

+        // cuDNN stream check needs to be moved here to support dummy kernel calls with
+        // null streams for sizing the cuDNN workspace.
+        NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
+
         constexpr size_t nthreads_per_block = 128;
         const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
         void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
...
@@ -1007,8 +1007,6 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
                              cudaStream_t stream,
                              cudnnHandle_t handle_) {
     try {
-        NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
-
         FADescriptor descriptor{
             b, h, s_q, s_kv, d,
             attnScale, isTraining, dropoutProbability, layout,
@@ -1212,6 +1210,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
             return;
         }

+        // cuDNN stream check needs to be moved here to support dummy kernel calls with
+        // null streams for sizing the cuDNN workspace.
+        NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
+
         int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
             reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
         int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
@@ -1324,8 +1326,6 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
                              cudaStream_t stream,
                              cudnnHandle_t handle_) {
     try {
-        NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
-
         FADescriptor descriptor{
             b, h, s_q, s_kv, d,
             attnScale, false, dropoutProbability, layout,
@@ -1745,6 +1745,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
             return;
         }

+        // cuDNN stream check needs to be moved here to support dummy kernel calls with
+        // null streams for sizing the cuDNN workspace.
+        NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
+
         int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
             reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
         int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
...
@@ -159,14 +159,6 @@ void layernorm_fwd(const Tensor& x,        // BxSxhidden_size
     const bool fp8_out = is_fp8_dtype(otype);
     const auto ctype = layer_norm::DType::kFloat32;

-    CheckInputTensor(x, "x");
-    CheckInputTensor(gamma, "gamma");
-    CheckInputTensor(beta, "beta");
-    CheckOutputTensor(*z, "z");
-    CheckOutputTensor(*mu, "mu");
-    CheckOutputTensor(*rsigma, "rsigma");
-
     NVTE_CHECK(x.data.shape.size() == 2);

     const size_t rows = x.data.shape[0];
@@ -227,6 +219,16 @@ void layernorm_fwd(const Tensor& x,        // BxSxhidden_size
         return;
     }

+    // Tensor checks are delayed here in order to recover workspace sizes with null data
+    CheckInputTensor(x, "x");
+    CheckInputTensor(gamma, "gamma");
+    CheckInputTensor(beta, "beta");
+    CheckOutputTensor(*z, "z");
+    CheckOutputTensor(*mu, "mu");
+    CheckOutputTensor(*rsigma, "rsigma");
+
     if ( launch_params.barrier_size > 0 ) {
         params.workspace = workspace->data.dptr;
         params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
@@ -273,15 +275,6 @@ void layernorm_bwd(const Tensor& dz,
     auto otype = wtype;
     auto ctype = DType::kFloat32;

-    CheckInputTensor(dz, "dz");
-    CheckInputTensor(x, "x");
-    CheckInputTensor(mu, "mu");
-    CheckInputTensor(rsigma, "rsigma");
-    CheckInputTensor(gamma, "gamma");
-    CheckOutputTensor(*dx, "dx");
-    CheckOutputTensor(*dgamma, "dgamma");
-    CheckOutputTensor(*dbeta, "dbeta");
-
     NVTE_CHECK(dz.data.dtype == otype);
     NVTE_CHECK(mu.data.dtype == ctype);
     NVTE_CHECK(rsigma.data.dtype == ctype);
@@ -354,6 +347,16 @@ void layernorm_bwd(const Tensor& dz,
         return;
     }

+    // Tensor checks are delayed here in order to recover workspace sizes with null data
+    CheckInputTensor(dz, "dz");
+    CheckInputTensor(x, "x");
+    CheckInputTensor(mu, "mu");
+    CheckInputTensor(rsigma, "rsigma");
+    CheckInputTensor(gamma, "gamma");
+    CheckOutputTensor(*dx, "dx");
+    CheckOutputTensor(*dgamma, "dgamma");
+    CheckOutputTensor(*dbeta, "dbeta");
+
     if ( launch_params.barrier_size > 0 ) {
         params.workspace = workspace->data.dptr;
         params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
...
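The delayed checks in `layernorm_fwd`/`layernorm_bwd` follow the same principle: `CheckInputTensor`/`CheckOutputTensor` validate data pointers, and those pointers are deliberately null during a workspace-size query. A condensed sketch of the reordered control flow (illustrative only; the early-return condition is simplified from the real code):

```cpp
// Condensed sketch of the reordering (not the literal TE function): the
// sizing path exits before any data-pointer validation runs.
void norm_fwd_sketch(const Tensor &x, Tensor *z, Tensor *workspace, Tensor *barrier) {
    // ... dtype/shape setup and launch-parameter computation ...

    if (workspace->data.dptr == nullptr) {
        // Sizing call: publish workspace/barrier shapes and return. x, z,
        // etc. may carry null data pointers here, so they must not be
        // checked yet.
        return;
    }

    // Tensor checks are delayed to this point so null-data sizing calls succeed.
    CheckInputTensor(x, "x");
    CheckOutputTensor(*z, "z");

    // ... launch the kernel ...
}
```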
@@ -113,12 +113,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
     const bool fp8_out = is_fp8_dtype(otype);
     auto ctype = DType::kFloat32;

-    CheckInputTensor(x, "x");
-    CheckInputTensor(gamma, "gamma");
-    CheckOutputTensor(*z, "z");
-    CheckOutputTensor(*rsigma, "rsigma");
-
     NVTE_CHECK(x.data.shape.size() == 2);

     const size_t rows = x.data.shape[0];
@@ -172,6 +166,15 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
         return;
     }

+    // Tensor checks are delayed here in order to recover workspace sizes with null data
+    CheckInputTensor(x, "x");
+    CheckInputTensor(gamma, "gamma");
+    CheckOutputTensor(*z, "z");
+    CheckOutputTensor(*rsigma, "rsigma");
+
     if (launch_params.barrier_size > 0) {
         params.workspace = workspace->data.dptr;
         params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
@@ -204,13 +207,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
     auto otype = wtype;
     auto ctype = DType::kFloat32;

-    CheckInputTensor(dz, "dz");
-    CheckInputTensor(x, "x");
-    CheckInputTensor(rsigma, "rsigma");
-    CheckInputTensor(gamma, "gamma");
-    CheckOutputTensor(*dx, "dx");
-    CheckOutputTensor(*dgamma, "dgamma");
-
     NVTE_CHECK(dz.data.dtype == otype);
     NVTE_CHECK(rsigma.data.dtype == ctype);
@@ -268,6 +264,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
         return;
     }

+    // Tensor checks are delayed here in order to recover workspace sizes with null data
+    CheckInputTensor(dz, "dz");
+    CheckInputTensor(x, "x");
+    CheckInputTensor(rsigma, "rsigma");
+    CheckInputTensor(gamma, "gamma");
+    CheckOutputTensor(*dx, "dx");
+    CheckOutputTensor(*dgamma, "dgamma");
+
     if (launch_params.barrier_size > 0) {
         params.workspace = workspace->data.dptr;
         params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
...
This diff is collapsed.
@@ -29,7 +29,6 @@ pybind11::dict Registrations() {
     dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
     dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
     dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
-    dict["te_gemm"] = EncapsulateFunction(Gemm);
     dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
     dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
     dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
@@ -56,14 +55,19 @@ pybind11::dict Registrations() {
 PYBIND11_MODULE(transformer_engine_jax, m) {
     m.def("registrations", &Registrations);
     m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
-    m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
     m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
     m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
-    m.def("get_cublasLt_version", &cublasLtGetVersion);
-    m.def("get_cuda_version", &GetCudaRuntimeVersion);
-    m.def("get_device_compute_capability", &GetDeviceComputeCapability);
     m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
     m.def("get_fused_attn_backend", &GetFusedAttnBackend);
+    m.def("get_cuda_version", &GetCudaRuntimeVersion);
+    m.def("get_device_compute_capability", &GetDeviceComputeCapability);
+    m.def("get_cublasLt_version", &cublasLtGetVersion);
+    m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
+    m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
+    m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
+    m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
+    m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
+    m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);

     pybind11::enum_<DType>(m, "DType", pybind11::module_local())
         .value("kByte", DType::kByte)
...
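Each `get_*_workspace_sizes` binding exposes the sizing pass to Python so the JAX primitives can declare their workspace tensors as additional custom-call outputs. A self-contained, hypothetical example of the pattern (none of these names appear in the PR):

```cpp
#include <pybind11/pybind11.h>
#include <cstddef>

// Hypothetical op following the PR's convention: null workspace => size query.
static void dummy_impl(void *workspace, size_t *workspace_size) {
    if (workspace == nullptr) {
        *workspace_size = 4194304;  // pretend the plan wants 4 MiB of scratch
        return;
    }
    // ... real work would use `workspace` here ...
}

// Getter surfaced to Python; the JAX side turns the returned size into the
// shape of an extra custom-call output, making XLA the allocator.
pybind11::tuple GetDummyWorkspaceSizes() {
    size_t workspace_size = 0;
    dummy_impl(nullptr, &workspace_size);  // sizing call, no GPU work
    return pybind11::make_tuple(workspace_size);
}

PYBIND11_MODULE(workspace_example, m) {
    m.def("get_dummy_workspace_sizes", &GetDummyWorkspaceSizes);
}
```

Because the getter runs the same code path as the real op, the reported size cannot drift out of sync with what execution actually consumes.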
This diff is collapsed.
@@ -52,68 +52,69 @@ struct CustomCallCommonDescriptor {
 pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                                DType out_dtype);

-struct CustomCallGemmDescriptor {
-    size_t m;
-    size_t n;
-    size_t k;
-    DType A_dtype;
-    DType B_dtype;
-    DType D_dtype;
-    bool transa;
-    bool transb;
-    bool use_split_accumulator;
-};
-
-pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
-                                             DType B_dtype, DType D_dtype, bool transa, bool transb,
-                                             bool use_split_accumulator);
-
 struct CustomCallNormDescriptor {
-    size_t n;
-    size_t hidden;
+    size_t batch_size;
+    size_t hidden_size;
+    size_t wkspace_size;
+    size_t barrier_size;
+    size_t *dgamma_part_sizes;  // 2D tensor
+    size_t *dbeta_part_sizes;   // 2D tensor
     DType x_dtype;
     DType w_dtype;
+    DType wkspace_dtype;
+    DType barrier_dtype;
+    DType dgamma_part_dtype;
+    DType dbeta_part_dtype;
     bool zero_centered_gamma;
     float eps;
     int sm_margin;
 };

-pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
+pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
+                                             size_t wkspace_size, size_t barrier_size,
+                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
+                                             DType x_dtype, DType w_dtype,
+                                             DType wkspace_dtype, DType barrier_dtype,
+                                             DType dgamma_part_dtype, DType dbeta_part_dtype,
                                              bool zero_centered_gamma, float eps, int sm_margin);

 struct SoftmaxDescriptor {
-    size_t batch;
-    size_t pad_batch;
-    size_t heads;
+    size_t batch_size;
+    size_t padding_size;
+    size_t head_dim;
     size_t q_seqlen;
     size_t k_seqlen;
     DType dtype;
     float scale_factor;
 };

-pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
-                                                size_t q_seqlen, size_t k_seqlen, DType dtype,
-                                                float scale_factor);
+pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
+                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
+                                                DType dtype, float scale_factor);

 struct CustomCallFusedAttnDescriptor {
-    size_t batch;
-    size_t num_head;
-    size_t num_gqa_groups;
+    size_t batch_size;
     size_t q_max_seqlen;
     size_t kv_max_seqlen;
+    size_t num_heads;
+    size_t num_gqa_groups;
     size_t head_dim;
+    size_t wkspace_size;
     float scaling_factor;
     float dropout_probability;
     NVTE_Bias_Type bias_type;
     NVTE_Mask_Type mask_type;
     DType dtype;
+    DType wkspace_dtype;
     bool is_training;
 };

 pybind11::bytes PackCustomCallFusedAttnDescriptor(
-    size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, DType dtype, bool is_training);
+    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    DType dtype, DType wkspace_dtype, bool is_training);

 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -135,13 +136,21 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
 void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);

-void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+pybind11::tuple GetLayerNormForwardWorkspaceSizes(
+    size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
+    bool is_layer_norm, bool zero_centered_gamma, float eps
+);

 void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

 void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len);

+pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
+    size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
+    bool zero_centered_gamma, float eps
+);
+
 void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

 void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -172,15 +181,41 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
 void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                             std::size_t opaque_len);

+pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
+    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);

+pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
+    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);

+pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
+    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t num_heads, size_t num_gqa_groups, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);

+pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
+    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t num_heads, size_t num_gqa_groups, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);
...
@@ -28,66 +28,6 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
                            size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                            cudaStream_t stream);

-class WorkspaceManager {
- public:
-    static WorkspaceManager &Instance() {
-        static thread_local WorkspaceManager instance;
-        return instance;
-    }
-
-    WorkspaceManager() {}
-    ~WorkspaceManager() { Clear_(); }
-
-    void *GetWorkspace(size_t size = 4194304) {
-        ReallocateIfNeed_(size);
-        return workspace_;
-    }
-
-    template <typename... Args>
-    inline auto GetWorkspace(Args... args) {
-        auto asks = std::array<size_t, sizeof...(Args)>{args...};
-        std::array<size_t, sizeof...(Args) + 1> offsets = {0};
-        std::array<void *, sizeof...(Args)> workspaces = {nullptr};
-        std::transform_inclusive_scan(
-            asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus<size_t>{},
-            [=](auto x) { return PadSize_(x); }, 0);
-        auto *workspace = GetWorkspace(offsets.back());
-        std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
-                       [workspace](auto x) { return static_cast<char *>(workspace) + x; });
-        return workspaces;
-    }
-
- private:
-    void *workspace_ = nullptr;
-    size_t size_ = 0;
-
-    size_t PadSize_(size_t size) {
-        constexpr size_t alignment = 128;
-        return ((size + alignment - 1) / alignment) * alignment;
-    }
-
-    void Clear_() {
-        if (workspace_ != nullptr) {
-            NVTE_CHECK_CUDA(cudaFree(workspace_));
-        }
-        workspace_ = nullptr;
-        size_ = 0;
-    }
-
-    void Allocate_(size_t new_size) {
-        new_size = PadSize_(new_size);
-        NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
-        size_ = new_size;
-    }
-
-    void ReallocateIfNeed_(size_t new_size) {
-        if (new_size > size_) {
-            Clear_();
-            Allocate_(new_size);
-        }
-    }
-};
-
 class cudaDevicePropertiesManager {
  public:
     static cudaDevicePropertiesManager &Instance() {
...
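With the thread-local cudaMalloc cache gone, a custom call finds its scratch memory among the buffers XLA hands it, sized earlier by the matching getter and recorded in the descriptor (the `wkspace_size`/`wkspace_dtype` fields above). A sketch with hypothetical buffer indices and an illustrative `UnpackDescriptor` helper (neither is the PR's literal code):

```cpp
// Hypothetical sketch: how a custom call consumes an XLA-allocated workspace.
void FusedOpExample(cudaStream_t stream, void **buffers,
                    const char *opaque, size_t opaque_len) {
    const auto &desc = UnpackDescriptor<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    void *q         = buffers[0];  // ordinary input
    void *output    = buffers[1];  // ordinary output
    void *workspace = buffers[2];  // extra output declared by the JAX primitive:
                                   // desc.wkspace_size elements of desc.wkspace_dtype,
                                   // allocated and lifetime-managed by XLA

    // Kernels launch on `stream` using `workspace`; there is no cudaMalloc,
    // no thread-local cache, and XLA is free to reuse the buffer afterwards.
}
```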
@@ -22,7 +22,7 @@ from ..dot import type_safe_dot_general
 from ..fp8 import FP8Helper, FP8MetaPackage
 from ..layernorm import canonicalize_layernorm_type
 from ..layernorm import layernorm, layernorm_fp8_dot
-from ..mlp import layernrom_geglu_fp8_mlp, geglu
+from ..mlp import layernorm_geglu_fp8_mlp, geglu
 from ..softmax import is_softmax_kernel_available
 from ..softmax import softmax, SoftmaxType
@@ -886,7 +886,7 @@ class LayerNormMLP(TransformerEngineBase):
         if use_fused_ln_mlp:
             assert self.axis == -1    # Only support axis == -1 at this moment
-            out = layernrom_geglu_fp8_mlp(y,
+            out = layernorm_geglu_fp8_mlp(y,
                                           scale,
                                           ln_bias, [kernel_1, kernel_2],
                                           fp8_meta_package,
...
@@ -55,7 +55,7 @@ def _geglu_bwd_rule(ctx, g):

 _geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)

-def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
+def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
                             gamma: jnp.ndarray,
                             beta: jnp.ndarray,
                             kernels: List[jnp.ndarray],
@@ -86,25 +86,25 @@ def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
     assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
         "if layernorm_type is 'rmsnorm'"

-    output = _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
+    output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
                                       scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                       zero_centered_gamma, epsilon)

     return output

 @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
-def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
+def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                              kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
                              amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                              fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
                              zero_centered_gamma: bool, epsilon: float):
-    output, _ = _layernrom_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
+    output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
                                                   scale, scale_inv, fwd_dtype, bwd_dtype,
                                                   layernorm_type, zero_centered_gamma, epsilon)

     return output

-def _layernrom_geglu_fp8_mlp_fwd_rule(
+def _layernorm_geglu_fp8_mlp_fwd_rule(
         x,
         gamma,
         beta,
@@ -209,7 +209,7 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
     return dot_2_output, ctx

-def _layernrom_geglu_fp8_mlp_bwd_rule(
+def _layernorm_geglu_fp8_mlp_bwd_rule(
         fwd_dtype,    # pylint: disable=unused-argument
         bwd_dtype,
         layernorm_type,
@@ -307,5 +307,5 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
     fp8_max, amax, scale, scale_inv

-_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule,
-                                _layernrom_geglu_fp8_mlp_bwd_rule)
+_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
+                                _layernorm_geglu_fp8_mlp_bwd_rule)