Unverified commit 4077ccc1, authored by Alp Dener, committed by GitHub
Browse files

[JAX] Custom Op Workspace Tensors from XLA Buffers (#532)



* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors.
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unused GEMM C++ API in TE-JAX
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed import order for linting
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed custom op errors due to incorrect static arg nums in JAX jit
Signed-off-by: Alp Dener <adener@nvidia.com>

* shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for blank lines
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent bd7fd0a6
...@@ -25,3 +25,5 @@ tests/cpp/build/ ...@@ -25,3 +25,5 @@ tests/cpp/build/
docs/_build docs/_build
.ipynb_checkpoints .ipynb_checkpoints
docs/doxygen docs/doxygen
*.log
CMakeFiles/CMakeSystem.cmake
\ No newline at end of file
...@@ -20,7 +20,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti ...@@ -20,7 +20,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernrom_geglu_fp8_mlp from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
GEMM_CASES = [ GEMM_CASES = [
(256, 256, 512), (256, 256, 512),
...@@ -196,7 +196,7 @@ class TestFP8Dot: ...@@ -196,7 +196,7 @@ class TestFP8Dot:
# out = (x * y) * z # out = (x * y) * z
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) fp8_metas_scale_inv)
return jnp.mean(layernrom_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
def _convert_to_activation_function(fn_or_string): def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function.""" """Convert a string to an activation function."""
......
...@@ -59,8 +59,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -59,8 +59,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
cudnn_frontend::DataType_t tensorType, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
...@@ -248,6 +246,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -248,6 +246,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
return; return;
} }
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// Build variant pack // Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = { std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ}, {Q, devPtrQ},
...@@ -300,8 +302,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -300,8 +302,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
...@@ -519,6 +519,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -519,6 +519,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
return; return;
} }
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// build variant pack // build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = { std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{q, devPtrQ}, {q, devPtrQ},
......
...@@ -642,8 +642,6 @@ void fused_attn_max_512_fwd_impl( ...@@ -642,8 +642,6 @@ void fused_attn_max_512_fwd_impl(
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size,
cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) {
try { try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{b, h, FADescriptor descriptor{b, h,
s_q, s_kv, s_q, s_kv,
d, scaling_factor, d, scaling_factor,
...@@ -754,6 +752,10 @@ void fused_attn_max_512_fwd_impl( ...@@ -754,6 +752,10 @@ void fused_attn_max_512_fwd_impl(
return; return;
} }
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// Prepare actual seqlen // Prepare actual seqlen
constexpr size_t nthreads_per_block = 128; constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
...@@ -845,9 +847,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv ...@@ -845,9 +847,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
size_t *workspace_size, cudnnDataType_t tensorType, size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
try { try {
// Create cudnn handle
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{ FADescriptor descriptor{
b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability,
layout, bias_type, mask_type, tensorType, false}; layout, bias_type, mask_type, tensorType, false};
...@@ -1194,6 +1193,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv ...@@ -1194,6 +1193,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
return; return;
} }
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
constexpr size_t nthreads_per_block = 128; constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size; void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
......
...@@ -1007,8 +1007,6 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in ...@@ -1007,8 +1007,6 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cudaStream_t stream, cudaStream_t stream,
cudnnHandle_t handle_) { cudnnHandle_t handle_) {
try { try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
FADescriptor descriptor{ FADescriptor descriptor{
b, h, s_q, s_kv, d, b, h, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability, layout, attnScale, isTraining, dropoutProbability, layout,
...@@ -1212,6 +1210,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in ...@@ -1212,6 +1210,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
return; return;
} }
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>( int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size); reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>( int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
...@@ -1324,8 +1326,6 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in ...@@ -1324,8 +1326,6 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cudaStream_t stream, cudaStream_t stream,
cudnnHandle_t handle_) { cudnnHandle_t handle_) {
try { try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
FADescriptor descriptor{ FADescriptor descriptor{
b, h, s_q, s_kv, d, b, h, s_q, s_kv, d,
attnScale, false, dropoutProbability, layout, attnScale, false, dropoutProbability, layout,
...@@ -1745,6 +1745,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in ...@@ -1745,6 +1745,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
return; return;
} }
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>( int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size); reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>( int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
......
...@@ -159,14 +159,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -159,14 +159,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const bool fp8_out = is_fp8_dtype(otype); const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32; const auto ctype = layer_norm::DType::kFloat32;
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0]; const size_t rows = x.data.shape[0];
...@@ -227,6 +219,16 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -227,6 +219,16 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
return; return;
} }
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if ( launch_params.barrier_size > 0 ) { if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr; params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr); params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
...@@ -273,15 +275,6 @@ void layernorm_bwd(const Tensor& dz, ...@@ -273,15 +275,6 @@ void layernorm_bwd(const Tensor& dz,
auto otype = wtype; auto otype = wtype;
auto ctype = DType::kFloat32; auto ctype = DType::kFloat32;
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
NVTE_CHECK(dz.data.dtype == otype); NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype); NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype); NVTE_CHECK(rsigma.data.dtype == ctype);
...@@ -354,6 +347,16 @@ void layernorm_bwd(const Tensor& dz, ...@@ -354,6 +347,16 @@ void layernorm_bwd(const Tensor& dz,
return; return;
} }
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
if ( launch_params.barrier_size > 0 ) { if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr; params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr); params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
......
...@@ -113,12 +113,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -113,12 +113,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
const bool fp8_out = is_fp8_dtype(otype); const bool fp8_out = is_fp8_dtype(otype);
auto ctype = DType::kFloat32; auto ctype = DType::kFloat32;
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0]; const size_t rows = x.data.shape[0];
...@@ -172,6 +166,15 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -172,6 +166,15 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
return; return;
} }
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) { if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr; params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr); params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
...@@ -204,13 +207,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -204,13 +207,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
auto otype = wtype; auto otype = wtype;
auto ctype = DType::kFloat32; auto ctype = DType::kFloat32;
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
NVTE_CHECK(dz.data.dtype == otype); NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(rsigma.data.dtype == ctype); NVTE_CHECK(rsigma.data.dtype == ctype);
...@@ -268,6 +264,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -268,6 +264,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
return; return;
} }
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
if (launch_params.barrier_size > 0) { if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr; params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr); params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
......
...@@ -10,13 +10,6 @@ import operator ...@@ -10,13 +10,6 @@ import operator
import os import os
import warnings import warnings
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
from jax.lib import xla_client from jax.lib import xla_client
...@@ -28,6 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding ...@@ -28,6 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src import dispatch from jax._src import dispatch
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices from .sharding import get_all_mesh_axes, num_of_devices
...@@ -58,6 +58,7 @@ def te_dtype_to_jax_dtype(te_dtype): ...@@ -58,6 +58,7 @@ def te_dtype_to_jax_dtype(te_dtype):
TEDType.kInt64: jnp.int64, TEDType.kInt64: jnp.int64,
TEDType.kFloat8E4M3: jnp.float8_e4m3fn, TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
TEDType.kFloat8E5M2: jnp.float8_e5m2, TEDType.kFloat8E5M2: jnp.float8_e5m2,
TEDType.kByte: jnp.uint8
} }
if te_dtype not in converter: if te_dtype not in converter:
...@@ -94,6 +95,7 @@ def jax_dtype_to_te_dtype(jax_dtype): ...@@ -94,6 +95,7 @@ def jax_dtype_to_te_dtype(jax_dtype):
jnp.int64.dtype: TEDType.kInt64, jnp.int64.dtype: TEDType.kInt64,
jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
jnp.uint8.dtype: TEDType.kByte,
} }
if jax_dtype not in converter: if jax_dtype not in converter:
...@@ -124,7 +126,7 @@ def _check_valid_batch_dims(bdims): ...@@ -124,7 +126,7 @@ def _check_valid_batch_dims(bdims):
class BasePrimitive(metaclass=ABCMeta): class BasePrimitive(metaclass=ABCMeta):
""" """
jax premitive jax primitive
""" """
@staticmethod @staticmethod
...@@ -135,6 +137,13 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -135,6 +137,13 @@ class BasePrimitive(metaclass=ABCMeta):
""" """
return NotImplemented return NotImplemented
@classmethod
def outer_abstract(cls, *args, **kwargs):
"""
optional abstract wrapper to eliminate workspace tensors
"""
return cls.abstract(*args, **kwargs)
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def lowering(): def lowering():
...@@ -196,7 +205,7 @@ def register_primitive(cls): ...@@ -196,7 +205,7 @@ def register_primitive(cls):
dispatch.prim_requires_devices_during_lowering.add(outer_p) dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl) outer_p.def_impl(cls.impl)
outer_p.def_abstract_eval(cls.abstract) outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands, outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
...@@ -287,9 +296,9 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -287,9 +296,9 @@ class LayerNormFwdPrimitive(BasePrimitive):
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(x_aval, gamma_aval, beta_aval, **kwargs): # pylint: disable=unused-argument def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
""" """
LayerNorm fwd abstract LayerNorm fwd inner primitive abstract
""" """
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -303,6 +312,28 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -303,6 +312,28 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size = gamma_aval.size hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0 assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
True, kwargs['zero_centered_gamma'], kwargs['epsilon']
)
wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, _, _ = \
LayerNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, mu_aval, rsigma_aval return out_aval, mu_aval, rsigma_aval
@staticmethod @staticmethod
...@@ -333,10 +364,14 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -333,10 +364,14 @@ class LayerNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, output_type), ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
] ]
operands = [x, gamma, beta] operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape] operand_shapes = [x_shape, g_shape, b_shape]
...@@ -347,8 +382,16 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -347,8 +382,16 @@ class LayerNormFwdPrimitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -364,7 +407,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -364,7 +407,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
to describe implementation to describe implementation
""" """
assert LayerNormFwdPrimitive.inner_primitive is not None assert LayerNormFwdPrimitive.inner_primitive is not None
out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind( out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
return out, mu, rsigma return out, mu, rsigma
...@@ -449,9 +492,9 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -449,9 +492,9 @@ class LayerNormBwdPrimitive(BasePrimitive):
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
""" """
Layernorm bwd abstract Layernorm bwd inner primitive abstract
""" """
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
...@@ -464,6 +507,34 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -464,6 +507,34 @@ class LayerNormBwdPrimitive(BasePrimitive):
dx_aval = core.raise_to_shaped(dz_aval) dx_aval = core.raise_to_shaped(dz_aval)
dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = \
transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
True, kwargs['zero_centered_gamma'], kwargs['epsilon']
)
wkspace_aval = dx_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = dx_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0],
dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]))
return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \
dgamma_part_aval, dbeta_part_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = \
LayerNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval, dbeta_aval return dx_aval, dgamma_aval, dbeta_aval
@staticmethod @staticmethod
...@@ -488,22 +559,32 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -488,22 +559,32 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape) hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [ out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
ir.RankedTensorType.get(g_shape, g_type.element_type), for output in ctx.avals_out
ir.RankedTensorType.get(b_shape, b_type.element_type),
] ]
operands = [dz, mu, rsigma, x, gamma] operands = [dz, mu, rsigma, x, gamma]
operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:]
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
dbeta_part_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
jax_dtype_to_te_dtype(dbeta_part_aval.dtype),
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -516,7 +597,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -516,7 +597,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
@staticmethod @staticmethod
def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
assert LayerNormBwdPrimitive.inner_primitive is not None assert LayerNormBwdPrimitive.inner_primitive is not None
dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind( dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
return dx, dgamma, dbeta return dx, dgamma, dbeta
...@@ -609,9 +690,9 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -609,9 +690,9 @@ class RmsNormFwdPrimitive(BasePrimitive):
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument def abstract(x_aval, gamma_aval, **kwargs):
""" """
RMSNorm fwd abstract RMSNorm fwd inner primitive abstract
""" """
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -624,6 +705,27 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -624,6 +705,27 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size = gamma_aval.size hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0 assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False, False, kwargs['epsilon']
)
wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, rsigma_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd outer primitive abstract
"""
out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval return out_aval, rsigma_aval
@staticmethod @staticmethod
...@@ -643,9 +745,13 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -643,9 +745,13 @@ class RmsNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type), ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type), ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
] ]
operands = [x, gamma] operands = [x, gamma]
operand_shapes = [x_shape, g_shape] operand_shapes = [x_shape, g_shape]
...@@ -656,8 +762,16 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -656,8 +762,16 @@ class RmsNormFwdPrimitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -673,7 +787,7 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -673,7 +787,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
to describe implementation to describe implementation
""" """
assert RmsNormFwdPrimitive.inner_primitive is not None assert RmsNormFwdPrimitive.inner_primitive is not None
out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
return out, rsigma return out, rsigma
@staticmethod @staticmethod
...@@ -744,15 +858,9 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -744,15 +858,9 @@ class RmsNormBwdPrimitive(BasePrimitive):
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract( def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
dz_aval,
x_aval,
rsigma_aval,
gamma_aval,
**kwargs # pylint: disable=unused-argument
):
""" """
RMSNorm bwd abstract RMSNorm bwd inner primitive abstract
""" """
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
...@@ -764,6 +872,30 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -764,6 +872,30 @@ class RmsNormBwdPrimitive(BasePrimitive):
dx_aval = core.raise_to_shaped(dz_aval) dx_aval = core.raise_to_shaped(dz_aval)
dgamma_aval = core.raise_to_shaped(gamma_aval) dgamma_aval = core.raise_to_shaped(gamma_aval)
wkspace_info, barrier_info, dgamma_part_info, _ = \
transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
False, False, kwargs['epsilon']
)
wkspace_aval = dx_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = dx_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval return dx_aval, dgamma_aval
@staticmethod @staticmethod
...@@ -782,9 +914,15 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -782,9 +914,15 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape) hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:]
out_types = [ out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type), ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(dgamma_part_aval.shape,
jax_dtype_to_ir_dtype(dgamma_part_aval.dtype))
] ]
operands = [dz, rsigma, x, gamma] operands = [dz, rsigma, x, gamma]
operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
...@@ -795,8 +933,16 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -795,8 +933,16 @@ class RmsNormBwdPrimitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
0, # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -809,7 +955,8 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -809,7 +955,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
@staticmethod @staticmethod
def impl(dz, x, rsigma, gamma, epsilon): def impl(dz, x, rsigma, gamma, epsilon):
assert RmsNormBwdPrimitive.inner_primitive is not None assert RmsNormBwdPrimitive.inner_primitive is not None
dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) dx, dgamma, _, _, _ = \
RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
return dx, dgamma return dx, dgamma
@staticmethod @staticmethod
...@@ -1721,40 +1868,60 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1721,40 +1868,60 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training): attn_mask_type, scaling_factor, dropout_probability, is_training):
""" """
Self fused attention fwd abstract Self fused attention fwd inner primitive abstract
""" """
# outer_primitve is seqlen, inner_primitive is cu_seqlen # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, scaling_factor, is_training del seqlen_or_cu_seqlen_aval
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape *batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape
assert nqkv == 3 assert nqkv == 3
assert qkv_aval.dtype == bias_aval.dtype assert qkv_aval.dtype == bias_aval.dtype
output_shape = (*batch_shape, max_seqlen, num_head, head_dim) output_shape = (*batch_shape, max_seqlen, num_heads, head_dim)
output_dtype = qkv_dtype out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, num_head, num_head, attn_mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen) softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype softmax_dtype = qkv_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1) softmax_shape = (*batch_shape, num_heads, max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else: else:
raise ValueError(f'Unsupported {backend=}') raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker() checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_dtype = seed_dtype rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, batch_shape)
wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
out_aval = qkv_aval.update(shape=output_shape, dtype=output_dtype) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype)
rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=rng_state_dtype) @staticmethod
def outer_abstract(*args, **kwargs):
"""
Self fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod @staticmethod
...@@ -1763,23 +1930,25 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1763,23 +1930,25 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
""" """
Self fused attention fwd lowering rules Self fused attention fwd lowering rules
""" """
qkv_aval, _, _, _ = ctx.avals_in
*batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
batch = reduce(operator.mul, batch_shape)
operands = [qkv, bias, cu_seqlen, seed] operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out for output in ctx.avals_out
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval = ctx.avals_in[0]
*batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size,
dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -1792,7 +1961,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1792,7 +1961,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
cu_seqlen = generate_cu_seqlen(seqlen) cu_seqlen = generate_cu_seqlen(seqlen)
output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind( output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
qkv, qkv,
bias, bias,
cu_seqlen, cu_seqlen,
...@@ -1897,16 +2066,35 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive): ...@@ -1897,16 +2066,35 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
""" """
Self fused attention bwd abstract Self fused attention bwd abstract
""" """
del softmax_aux_aval, rng_state_aval del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
del scaling_factor, dropout_probability, is_training *batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape
assert nqkv == 3
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
batch_size = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype) dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = qkv_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dqkv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Self fused attention bwd outer primitive abstract
"""
dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dqkv_aval, dbias_aval return dqkv_aval, dbias_aval
@staticmethod @staticmethod
...@@ -1915,24 +2103,25 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive): ...@@ -1915,24 +2103,25 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
""" """
Self fused attention bwd lowering rules Self fused attention bwd lowering rules
""" """
qkv_aval, _, _, _, _, _, _ = ctx.avals_in
*batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
batch = reduce(operator.mul, batch_shape)
operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen] operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out for output in ctx.avals_out
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval = ctx.avals_in[0]
*batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size,
dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -1945,7 +2134,7 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive): ...@@ -1945,7 +2134,7 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
cu_seqlen = generate_cu_seqlen(seqlen) cu_seqlen = generate_cu_seqlen(seqlen)
dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind( dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
qkv, qkv,
bias, bias,
softmax_aux, softmax_aux,
...@@ -2067,50 +2256,62 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2067,50 +2256,62 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
""" """
Cross fused attention fwd abstract Cross fused attention fwd abstract
""" """
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del scaling_factor, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
*q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
*kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == kv_dtype == bias_dtype assert q_dtype == kv_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim assert q_head_dim == kv_head_dim
assert nkv == 2 assert nkv == 2
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
output_shape = q_aval.shape
output_dtype = q_dtype
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, q_num_head, attn_bias_type, attn_mask_type, dropout_probability,
kv_num_head, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend() q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen) softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_aux_dtype = q_dtype softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1) softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1)
softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else: else:
raise ValueError(f'Unsupported {backend=}') raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker() checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_dtype = seed_dtype rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
out_aval = q_aval.update(shape=output_shape, dtype=output_dtype) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod @staticmethod
...@@ -2119,25 +2320,27 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2119,25 +2320,27 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
""" """
Cross fused attention fwd lowering rules Cross fused attention fwd lowering rules
""" """
q_aval, kv_aval, *_ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed] operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out for output in ctx.avals_out
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim, batch_size, q_max_seqlen, kv_max_seqlen,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -2151,7 +2354,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2151,7 +2354,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
q_cu_seqlen = generate_cu_seqlen(q_seqlen) q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind( output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q, q,
kv, kv,
bias, bias,
...@@ -2266,7 +2469,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2266,7 +2469,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
Cross fused attention bwd abstract Cross fused attention bwd abstract
""" """
del softmax_aux_aval, rng_state_aval, output_aval del softmax_aux_aval, rng_state_aval, output_aval
del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
...@@ -2274,9 +2477,35 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2274,9 +2477,35 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype == doutput_dtype assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype) dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dkv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention fwd outer primitive abstract
"""
dq_aval, dkv_aval, dbias_aval, _ = \
CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dkv_aval, dbias_aval return dq_aval, dkv_aval, dbias_aval
@staticmethod @staticmethod
...@@ -2286,13 +2515,6 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2286,13 +2515,6 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
""" """
Cross fused attention bwd lowering rules Cross fused attention bwd lowering rules
""" """
q_aval, kv_aval, *_ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen] operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
...@@ -2302,12 +2524,19 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2302,12 +2524,19 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# the dropout elements are encoded in the forward auxiliary tensor q_aval, kv_aval, *_ = ctx.avals_in
# so seed is not needed in backward *batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim, batch_size, q_max_seqlen, kv_max_seqlen,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -2321,7 +2550,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2321,7 +2550,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen = generate_cu_seqlen(q_seqlen) q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind( dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q, q,
kv, kv,
bias, bias,
...@@ -3143,9 +3372,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3143,9 +3372,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
zero_centered_gamma, epsilon): zero_centered_gamma, epsilon):
""" """
LayerNorm fwd (fp8 out) abstract LayerNorm fwd (fp8 out) inner primitive abstract
""" """
del zero_centered_gamma, epsilon
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -3157,10 +3385,32 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3157,10 +3385,32 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert gamma_aval.size == beta_aval.size assert gamma_aval.size == beta_aval.size
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type
jax_dtype_to_te_dtype(out_dtype),
True, zero_centered_gamma, epsilon
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_aval = x_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = x_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = \
LayerNormFwdFp8Primitive.abstract(*args, **kwargs)
return out_aval, mu_aval, rsigma_aval, updated_amax_aval return out_aval, mu_aval, rsigma_aval, updated_amax_aval
@staticmethod @staticmethod
...@@ -3204,11 +3454,15 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3204,11 +3454,15 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
] ]
operands = [x, gamma, beta, amax, scale, scale_inv] operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [ operand_shapes = [
...@@ -3221,8 +3475,16 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3221,8 +3475,16 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -3242,7 +3504,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3242,7 +3504,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
to describe implementation to describe implementation
""" """
assert LayerNormFwdFp8Primitive.inner_primitive is not None assert LayerNormFwdFp8Primitive.inner_primitive is not None
out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind( out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
x, x,
gamma, gamma,
beta, beta,
...@@ -3359,9 +3621,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3359,9 +3621,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
@staticmethod @staticmethod
def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon): def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
""" """
RMSNorm fwd (fp8 out) abstract RMSNorm fwd (fp8 out) inner primitive abstract
""" """
del epsilon
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -3374,10 +3635,31 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3374,10 +3635,31 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
rsigama_dtype = jnp.float32 rsigama_dtype = jnp.float32
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch_size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(out_dtype), # out te_dtype
False, False, epsilon
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_aval = x_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = x_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval, amax_aval return out_aval, rsigma_aval, amax_aval
@staticmethod @staticmethod
...@@ -3414,10 +3696,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3414,10 +3696,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
] ]
operands = [x, gamma, amax, scale, scale_inv] operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
...@@ -3428,8 +3714,16 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3428,8 +3714,16 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -3449,7 +3743,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3449,7 +3743,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
to describe implementation to describe implementation
""" """
assert RmsNormFwdFp8Primitive.inner_primitive is not None assert RmsNormFwdFp8Primitive.inner_primitive is not None
out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x, out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
gamma, gamma,
amax, amax,
scale, scale,
......
...@@ -29,7 +29,6 @@ pybind11::dict Registrations() { ...@@ -29,7 +29,6 @@ pybind11::dict Registrations() {
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose); dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
dict["te_gemm"] = EncapsulateFunction(Gemm);
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8); dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
...@@ -56,14 +55,19 @@ pybind11::dict Registrations() { ...@@ -56,14 +55,19 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) { PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations); m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_fused_attn_backend", &GetFusedAttnBackend);
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
pybind11::enum_<DType>(m, "DType", pybind11::module_local()) pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte) .value("kByte", DType::kByte)
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
#include "transformer_engine/gemm.h"
#include "transformer_engine/layer_norm.h" #include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h" #include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h" #include "transformer_engine/softmax.h"
...@@ -33,11 +32,12 @@ ...@@ -33,11 +32,12 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
constexpr size_t kCublasLtForwardWorkspaceSize = 32 * 1024 * 1024;
constexpr size_t kCublasLtBackwardWorkspaceSize = 32 * 1024 * 1024;
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
template <typename T> template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) { pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T)); auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
...@@ -61,33 +61,37 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, ...@@ -61,33 +61,37 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape,
return PackOpaque(desc); return PackOpaque(desc);
} }
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype, pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
DType B_dtype, DType D_dtype, bool transa, bool transb, size_t wkspace_size, size_t barrier_size,
bool use_split_accumulator) { size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
return PackOpaque(CustomCallGemmDescriptor{m, n, k, A_dtype, B_dtype, D_dtype, transa, transb, DType x_dtype, DType w_dtype,
use_split_accumulator}); DType wkspace_dtype, DType barrier_dtype,
} DType dgamma_part_dtype, DType dbeta_part_dtype,
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
bool zero_centered_gamma, float eps, int sm_margin) { bool zero_centered_gamma, float eps, int sm_margin) {
return PackOpaque( return PackOpaque(CustomCallNormDescriptor{batch_size, hidden_size, wkspace_size, barrier_size,
CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps, sm_margin}); dgamma_part_sizes, dbeta_part_sizes,
x_dtype, w_dtype, wkspace_dtype, barrier_dtype,
dgamma_part_dtype, dbeta_part_dtype,
zero_centered_gamma, eps, sm_margin});
} }
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads, pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t q_seqlen, size_t k_seqlen, DType dtype, size_t head_dim, size_t q_seqlen, size_t k_seqlen,
float scale_factor) { DType dtype, float scale_factor) {
return PackOpaque( return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor}); dtype, scale_factor});
} }
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) { float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{ return PackOpaque(CustomCallFusedAttnDescriptor{
batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor, batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size,
dropout_probability, bias_type, mask_type, dtype, is_training}); scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype,
is_training});
} }
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
...@@ -247,48 +251,56 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op ...@@ -247,48 +251,56 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
output_trans_tensor.data(), stream); output_trans_tensor.data(), stream);
} }
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { pybind11::tuple GetLayerNormForwardWorkspaceSizes(
auto *A = buffers[0]; size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
auto *B = buffers[1]; bool is_layer_norm, bool zero_centered_gamma, float eps
auto *A_scale_inverse = reinterpret_cast<float *>(buffers[2]); ) {
auto *B_scale_inverse = reinterpret_cast<float *>(buffers[3]); auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto *D = buffers[4]; auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
// We transposes shape of A, B and D here to correctly invoke
// cuBlasLt GEMM (col-major) for row-major data. // empty tensor wrappers are okay just to get workspace size
const auto &desc = *UnpackOpaque<CustomCallGemmDescriptor>(opaque, opaque_len); auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto m = desc.m; auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
auto n = desc.n; auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto k = desc.k;
auto A_shape = std::vector<size_t>{k, m}; // dummy tensor wrappers that will carry workspace size info later
auto A_tensor = TensorWrapper(A, A_shape, desc.A_dtype, nullptr, nullptr, A_scale_inverse); TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto B_shape = std::vector<size_t>{n, k}; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
auto B_tensor = TensorWrapper(B, B_shape, desc.B_dtype, nullptr, nullptr, B_scale_inverse); if (is_layer_norm) {
auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto D_shape = std::vector<size_t>{n, m}; auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto D_tensor = TensorWrapper(D, D_shape, desc.D_dtype);
auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32);
size_t workspace_size = kCublasLtForwardWorkspaceSize; layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte); num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
}
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(), auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N, auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
(desc.transb) ? CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false, return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
desc.use_split_accumulator, 0, stream); std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
} }
void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void LayerNormForwardImpl(size_t batch_size, size_t hidden_size,
int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype, size_t workspace_size, size_t barrier_size,
void *bias, void *output, DType out_dtype, void *mu, void *rsigma, bool zero_centered_gamma, float eps, void *input, DType in_dtype,
float *amax, float *scale, float *scale_inv, cudaStream_t stream) { void *weight, DType w_dtype, void *bias, void *output, DType out_dtype,
auto input_shape = std::vector<size_t>{n, hidden}; void *workspace, DType work_dtype, void *barrier, DType barrier_dtype,
auto weight_shape = std::vector<size_t>{hidden}; void *mu, void *rsigma, float *amax, float *scale, float *scale_inv,
auto intermediates_shape = std::vector<size_t>{n}; cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto workspace_shape = std::vector<size_t>{workspace_size};
auto barrier_shape = std::vector<size_t>{barrier_size};
auto is_layer_norm = (bias) ? true : false; auto is_layer_norm = (bias) ? true : false;
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
...@@ -300,63 +312,95 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo ...@@ -300,63 +312,95 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
// Create uninitialized workspace, barrier and init them on the first auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the required workspace auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
if (is_layer_norm) { if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data()); num_sm, workspace_tensor.data(), barrier_tensor.data());
} else { } else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, dummy_workspace_tensor.data(), rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
dummy_barrier_tensor.data()); barrier_tensor.data());
} }
}
size_t workspace_size = pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) + size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype()); bool zero_centered_gamma, float eps
) {
void *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32;
auto workspace_tensor = // empty tensor wrappers are okay just to get workspace size
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype()); auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto barrier_tensor = // dummy tensor wrappers that will carry workspace size info later
TensorWrapper(reinterpret_cast<char *>(workspace) + dummy_workspace_tensor.shape().data[0], TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype()); TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
// initialize dBeta information here -- layernorm will modify but RMSnorm will not
std::vector<size_t> dbeta_part_shape;
if (is_layer_norm) { if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
num_sm, workspace_tensor.data(), barrier_tensor.data()); wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
} else { } else {
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
barrier_tensor.data()); gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0};
} }
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
} }
void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size,
int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype, size_t wkspace_size, size_t barrier_size,
void *ograd, void *mu, void *rsigma, void *xgrad, void *wgrad, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
void *dbeta, cudaStream_t stream) { bool zero_centered_gamma, float eps,
auto input_shape = std::vector<size_t>{n, hidden}; void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd,
auto weight_shape = std::vector<size_t>{hidden}; void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype,
auto intermediates_shape = std::vector<size_t>{n}; void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta,
void *dgamma_part, DType dgamma_dtype,
void* dbeta_part, DType dbeta_dtype,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32; auto intermediates_dtype = DType::kFloat32;
auto is_layer_norm = (dbeta) ? true : false; auto is_layer_norm = (dbeta) ? true : false;
...@@ -374,62 +418,21 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl ...@@ -374,62 +418,21 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype); auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
size_t dbeta_part_size{};
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the workspace auto workspace_shape = std::vector<size_t>{wkspace_size};
if (is_layer_norm) { auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto barrier_shape = std::vector<size_t>{barrier_size};
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_size = dummy_dbeta_part_tensor.shape().data[0] *
dummy_dbeta_part_tensor.shape().data[1] *
typeToSize(dummy_dbeta_part_tensor.dtype());
} else {
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), stream, num_sm,
dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
}
size_t workspace_size =
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype());
size_t barrier_size =
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] *
dummy_dgamma_part_tensor.shape().data[1] *
typeToSize(dummy_dgamma_part_tensor.dtype());
auto [workspace, dgamma_part, dbeta_part, barrier] = WorkspaceManager::Instance().GetWorkspace(
workspace_size, dgamma_part_size, dbeta_part_size, barrier_size);
auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
auto barrier_tensor =
TensorWrapper(barrier, dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dummy_dgamma_part_tensor.shape(),
dummy_dgamma_part_tensor.dtype());
if (is_layer_norm) { if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(), auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
dummy_dbeta_part_tensor.dtype()); auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
...@@ -437,6 +440,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl ...@@ -437,6 +440,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(), dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data()); barrier_tensor.data());
} else { } else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(), dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
...@@ -456,22 +460,29 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -456,22 +460,29 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *mu = buffers[7]; auto *mu = buffers[7];
auto *rsigma = buffers[8]; auto *rsigma = buffers[8];
auto *amax_out = buffers[9]; auto *amax_out = buffers[9];
auto *workspace = buffers[10];
auto *barrier = buffers[11];
assert(amax_out == amax); assert(amax_out == amax);
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n; auto batch_size = desc.batch_size;
auto hidden = desc.hidden; auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
stream); output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
} }
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -481,33 +492,48 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -481,33 +492,48 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto *output = buffers[3]; auto *output = buffers[3];
auto *mu = buffers[4]; auto *mu = buffers[4];
auto *rsigma = buffers[5]; auto *rsigma = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
float *amax = nullptr; float *amax = nullptr;
float *scale = nullptr; float *scale = nullptr;
float *scale_inv = nullptr; float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n; auto batch_size = desc.batch_size;
auto hidden = desc.hidden; auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
stream); output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
} }
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n; auto batch_size = desc.batch_size;
auto hidden = desc.hidden; auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto *dgamma_part_sizes = desc.dgamma_part_sizes;
auto *dbeta_part_sizes = desc.dbeta_part_sizes;
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
...@@ -520,9 +546,16 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -520,9 +546,16 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *xgrad = buffers[5]; auto *xgrad = buffers[5];
auto *wgrad = buffers[6]; auto *wgrad = buffers[6];
auto *dbeta = buffers[7]; auto *dbeta = buffers[7];
auto *workspace = buffers[8];
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, auto *barrier = buffers[9];
w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream); auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
} }
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -534,24 +567,31 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -534,24 +567,31 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *output = buffers[5]; auto *output = buffers[5];
auto *rsigma = buffers[6]; auto *rsigma = buffers[6];
auto *amax_out = buffers[7]; auto *amax_out = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
assert(amax_out == amax); assert(amax_out == amax);
void *bias = nullptr; void *bias = nullptr;
void *mu = nullptr; void *mu = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n; auto batch_size = desc.batch_size;
auto hidden = desc.hidden; auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
stream); output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
} }
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -559,6 +599,8 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -559,6 +599,8 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto *weight = buffers[1]; auto *weight = buffers[1];
auto *output = buffers[2]; auto *output = buffers[2];
auto *rsigma = buffers[3]; auto *rsigma = buffers[3];
auto *workspace = buffers[4];
auto *barrier = buffers[5];
void *bias = nullptr; void *bias = nullptr;
void *mu = nullptr; void *mu = nullptr;
...@@ -567,18 +609,23 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -567,18 +609,23 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
float *scale_inv = nullptr; float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n; auto batch_size = desc.batch_size;
auto hidden = desc.hidden; auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
stream); output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
} }
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -588,21 +635,35 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si ...@@ -588,21 +635,35 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto *weight = buffers[3]; auto *weight = buffers[3];
auto *xgrad = buffers[4]; auto *xgrad = buffers[4];
auto *wgrad = buffers[5]; auto *wgrad = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
auto *dgamma_part = buffers[8];
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n; auto batch_size = desc.batch_size;
auto hidden = desc.hidden; auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_sizes = desc.dgamma_part_sizes;
size_t dbeta_part_sizes[2] = {0, 0};
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = DType::kByte;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
void *mu = nullptr;
void *dbeta = nullptr;
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream); dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
} }
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -645,7 +706,7 @@ void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -645,7 +706,7 @@ void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaqu
auto *output = buffers[1]; auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen}; auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype; auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype); auto input_tensor = TensorWrapper(input, shape, dtype);
...@@ -662,7 +723,7 @@ void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -662,7 +723,7 @@ void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaq
auto *dgrad = buffers[2]; auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen}; auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype; auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
...@@ -680,8 +741,9 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char ...@@ -680,8 +741,9 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char
auto *output = buffers[2]; auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen}; auto io_shape = std::vector<size_t>{desc.batch_size, desc.head_dim,
auto mask_shape = std::vector<size_t>{desc.pad_batch, 1, desc.q_seqlen, desc.k_seqlen}; desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype; auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, io_shape, dtype); auto input_tensor = TensorWrapper(input, io_shape, dtype);
...@@ -705,7 +767,7 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, ...@@ -705,7 +767,7 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
auto *output = buffers[1]; auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch * desc.heads; auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen}; auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype; auto dtype = desc.dtype;
...@@ -724,7 +786,7 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -724,7 +786,7 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
auto *dgrad = buffers[2]; auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch * desc.heads; auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen}; auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype; auto dtype = desc.dtype;
...@@ -750,91 +812,225 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -750,91 +812,225 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
return backend; return backend;
} }
/*
NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
      attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(
    NVTETensorPack *tensor_pack, const CustomCallFusedAttnDescriptor *desc,
    NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
    void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr
) {
    const auto b = desc->batch_size;
    const auto h = desc->num_heads;
    const auto sq = desc->q_max_seqlen;
    const auto skv = desc->kv_max_seqlen;

    // Every backend emits a softmax auxiliary tensor, but with different shapes/dtypes.
    // Start from the max512-seqlen kernel's {B,H,Qs,Ks} layout and patch it below.
    tensor_pack->size = 1;
    auto *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
    softmax_aux->data.dptr = softmax_buf;
    softmax_aux->data.shape = std::vector<size_t>{b, h, sq, skv};
    softmax_aux->data.dtype = desc->dtype;

    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
        // The arbitrary-seqlen kernel additionally hands back its RNG state...
        tensor_pack->size = 2;
        auto *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
        rng_state_aux->data.dptr = rng_state_buf;
        rng_state_aux->data.shape = std::vector<size_t>{2};
        rng_state_aux->data.dtype = DType::kInt64;

        // ...and stores softmax statistics as FP32 with shape {B,H,Qs,1} instead of
        // {B,H,Qs,Ks}, so correct the tensor initialized above.
        softmax_aux->data.shape.at(3) = 1;
        softmax_aux->data.dtype = DType::kFloat32;

        // When a (non-ALiBi) bias is active it is forwarded to the backward pass as a
        // third auxiliary tensor.
        if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
            tensor_pack->size = 3;
            auto *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
            bias_aux->data.dptr = bias_buf;
            bias_aux->data.shape = std::vector<size_t>{1, h, sq, skv};
            bias_aux->data.dtype = desc->dtype;
        }
    }
}
/*
NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
      instead of an NVTETensorPack, and the nvte_fused_attn_bwd() API does all the logic for
      pulling the necessary tensors out of the tensor pack for the active kernel. That means we
      can just dump everything we got into the tensor pack and not worry about its sizing for
      the backward pass.

      TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(
    NVTETensorPack* tensor_pack, const CustomCallFusedAttnDescriptor *desc,
    NVTE_Fused_Attn_Backend backend, void* softmax_buf, void* rng_state_buf, void* bias_buf
) {
    // Populate the pack via the forward-pass helper using the most inclusive configuration
    // (arbitrary-seqlen backend + post-scale bias) so every auxiliary tensor gets filled in,
    // regardless of which backend will actually consume the pack.
    PrepareFusedAttnForwardAuxTensors(tensor_pack, desc,
                                      NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
                                      NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen,
                                      softmax_buf, rng_state_buf, bias_buf);

    // The helper above leaves the softmax tensor in the arbitrary-seqlen layout; restore the
    // max512 kernel's layout/dtype when that backend is the one running.
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
        auto *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
        softmax_aux->data.shape.at(3) = desc->kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
        softmax_aux->data.dtype = desc->dtype;
    }
}
// Runs the self-attention forward kernel in query mode (all-null device pointers, null
// stream) so it reports the workspace shape/dtype it would need; returns (shape, dtype)
// as a Python tuple for the XLA custom-op to size its workspace output buffer.
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

    // Shape/dtype-only descriptors; no device memory is attached for the query pass.
    auto qkv_tensor = TensorWrapper(
        nullptr, std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim}, dtype);
    auto bias_tensor = TensorWrapper(
        nullptr, std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen}, dtype);
    auto cu_seqlens_tensor = TensorWrapper(
        nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
    auto o_tensor = TensorWrapper(
        nullptr, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // unused for F16
    auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

    // NOTE(review): the selected backend is queried but its result is not consumed here;
    // kept to match the original call sequence exactly.
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
        bias_type, mask_type, dropout_probability, num_heads, num_heads,
        max_seqlen, max_seqlen, head_dim);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    // A default-constructed workspace tensor plus a null stream puts the kernel in
    // workspace-query mode: it fills in the required shape/dtype without launching work.
    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                  rng_state_tensor.data(), max_seqlen, is_training,
                                  scaling_factor, dropout_probability, qkv_layout,
                                  bias_type, mask_type, query_workspace_tensor.data(), nullptr);

    auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input // input buffers from XLA
void *qkv = buffers[0]; void *qkv = buffers[0];
void *bias = buffers[1]; void *bias = buffers[1];
void *cu_seqlens = buffers[2]; void *cu_seqlens = buffers[2];
void *seed = buffers[3]; void *seed = buffers[3];
// output // output buffers from XLA
void *output = buffers[4]; void *output = buffers[4];
void *softmax_aux = buffers[5]; void *softmax_aux = buffers[5];
void *rng_state = buffers[6]; void *rng_state = buffers[6];
void *workspace = buffers[7];
auto batch = descriptor.batch; // tensor sizes
auto num_head = descriptor.num_head; auto batch_size = descriptor.batch_size;
auto num_gqa_groups = descriptor.num_gqa_groups; auto max_seqlen = descriptor.q_max_seqlen;
auto q_max_seqlen = descriptor.q_max_seqlen; auto num_heads = descriptor.num_heads;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
// input tensors // input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor = auto cu_seqlens_tensor = TensorWrapper(
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// output tensors // output tensors
auto o_tensor = auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype); auto o_tensor = TensorWrapper(
output, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// aux tensors // prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
auto backend = // auxiliary tensors (to be propagated to the backward pass later)
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
TensorWrapper query_workspace_tensor; // cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, rng_state_tensor.data(), max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), stream); bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]); pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
output_s->data.dptr = softmax_aux; size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto workspace_size = query_workspace_tensor.shape().data[0]; auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); auto output_shape = std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim};
auto workspace_tensor = auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
descriptor.scaling_factor, dropout_probability, qkv_layout, // F16 doesn't use this tensor
bias_type, mask_type, workspace_tensor.data(), stream); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
nvte_tensor_pack_destroy(&aux_output_tensors); auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1},
DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
} }
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
...@@ -842,7 +1038,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -842,7 +1038,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input // input buffers from XLA
void *qkv = buffers[0]; void *qkv = buffers[0];
void *bias = buffers[1]; void *bias = buffers[1];
void *softmax_aux = buffers[2]; void *softmax_aux = buffers[2];
...@@ -851,82 +1047,107 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -851,82 +1047,107 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
void *doutput = buffers[5]; void *doutput = buffers[5];
void *cu_seqlens = buffers[6]; void *cu_seqlens = buffers[6];
// output // output buffers from XLA
void *dqkv = buffers[7]; void *dqkv = buffers[7];
void *dbias = buffers[8]; void *dbias = buffers[8];
void *workspace = buffers[9];
auto batch = descriptor.batch; // tensor sizes
auto num_head = descriptor.num_head; auto batch_size = descriptor.batch_size;
auto num_gqa_groups = descriptor.num_gqa_groups; auto max_seqlen = descriptor.q_max_seqlen;
auto q_max_seqlen = descriptor.q_max_seqlen; auto num_heads = descriptor.num_heads;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto output_shape = std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype); auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor = auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); TensorWrapper(cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// TODO: needs to think about how to pass aux_output_tensors // auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_output_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
aux_output_tensors.size = 3; auto backend = nvte_get_fused_attn_backend(
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]); static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
output_s->data.dptr = softmax_aux; bias_type, mask_type, dropout_probability, num_heads, num_heads,
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]); max_seqlen, max_seqlen, head_dim);
rng_state_tensor->data.shape = std::vector<size_t>{2}; PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
rng_state_tensor->data.dtype = DType::kInt64; softmax_aux, rng_state, bias);
rng_state_tensor->data.dptr = rng_state;
auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
TensorWrapper query_workspace_tensor; // cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor, cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream); workspace_tensor.data(), stream);
size_t workspace_size = query_workspace_tensor.shape().data[0]; nvte_tensor_pack_destroy(&aux_input_tensors);
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); }
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
s_tensor.data(), // not used for F16 size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
s_tensor.data(), // not used for F16 size_t num_heads, size_t num_gqa_groups, size_t head_dim,
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(), float scaling_factor, float dropout_probability,
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
dropout_probability, qkv_layout, bias_type, mask_type, ) {
workspace_tensor.data(), stream); constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
nvte_tensor_pack_destroy(&aux_output_tensors); auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
} }
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
...@@ -934,7 +1155,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -934,7 +1155,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input // input buffers from XLA
void *q = buffers[0]; void *q = buffers[0];
void *kv = buffers[1]; void *kv = buffers[1];
void *bias = buffers[2]; void *bias = buffers[2];
...@@ -942,83 +1163,115 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -942,83 +1163,115 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
void *kv_cu_seqlens = buffers[4]; void *kv_cu_seqlens = buffers[4];
void *seed = buffers[5]; void *seed = buffers[5];
// output // output buffers from XLA
void *output = buffers[6]; void *output = buffers[6];
void *softmax_aux = buffers[7]; void *softmax_aux = buffers[7];
void *rng_state = buffers[8]; void *rng_state = buffers[8];
void *workspace = buffers[9];
auto batch = descriptor.batch; // tensor sizes
auto num_head = descriptor.num_head; auto batch_size = descriptor.batch_size;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto dtype = descriptor.dtype; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors // input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype); auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// output tensors
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
// aux tensors
// F16 doesn't use s_tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
auto backend = static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
qkv_layout, bias_type, mask_type, dropout_probability, num_head, q_max_seqlen, kv_max_seqlen, head_dim);
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
TensorWrapper query_workspace_tensor; // cuDNN workspace
auto workspace_tensor = TensorWrapper(
workspace, std::vector<size_t>{descriptor.wkspace_size}, descriptor.wkspace_dtype);
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream); workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]); nvte_tensor_pack_destroy(&aux_output_tensors);
output_s->data.dptr = softmax_aux; }
auto workspace_size = query_workspace_tensor.shape().data[0]; pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
auto workspace_tensor = size_t num_heads, size_t num_gqa_groups, size_t head_dim,
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
nvte_fused_attn_fwd_kvpacked( auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
} }
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
...@@ -1026,7 +1279,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1026,7 +1279,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input // input buffers from XLA
void *q = buffers[0]; void *q = buffers[0];
void *kv = buffers[1]; void *kv = buffers[1];
void *bias = buffers[2]; void *bias = buffers[2];
...@@ -1037,85 +1290,72 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1037,85 +1290,72 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
void *q_cu_seqlens = buffers[7]; void *q_cu_seqlens = buffers[7];
void *kv_cu_seqlens = buffers[8]; void *kv_cu_seqlens = buffers[8];
// output // output buffers from XLA
void *dq = buffers[9]; void *dq = buffers[9];
void *dkv = buffers[10]; void *dkv = buffers[10];
void *dbias = buffers[11]; void *dbias = buffers[11];
void *workspace = buffers[12];
auto batch = descriptor.batch; // tensor sizes
auto num_head = descriptor.num_head; auto batch_size = descriptor.batch_size;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto dtype = descriptor.dtype; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype); auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype); auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype); auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
softmax_aux, rng_state, bias);
auto q_cu_seqlens_tensor = // cuDNN workspace
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto kv_cu_seqlens_tensor = auto wkspace_dtype = descriptor.wkspace_dtype;
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
// TODO(rewang): need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 3;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
mask_type, query_workspace_tensor.data(), stream);
size_t workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type, scaling_factor, dropout_probability, qkv_layout,
mask_type, workspace_tensor.data(), stream); bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
} // namespace jax } // namespace jax
......
...@@ -52,68 +52,69 @@ struct CustomCallCommonDescriptor { ...@@ -52,68 +52,69 @@ struct CustomCallCommonDescriptor {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype, pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype); DType out_dtype);
struct CustomCallGemmDescriptor {
size_t m;
size_t n;
size_t k;
DType A_dtype;
DType B_dtype;
DType D_dtype;
bool transa;
bool transb;
bool use_split_accumulator;
};
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
DType B_dtype, DType D_dtype, bool transa, bool transb,
bool use_split_accumulator);
struct CustomCallNormDescriptor { struct CustomCallNormDescriptor {
size_t n; size_t batch_size;
size_t hidden; size_t hidden_size;
size_t wkspace_size;
size_t barrier_size;
size_t *dgamma_part_sizes; // 2D tensor
size_t *dbeta_part_sizes; // 2D tensor
DType x_dtype; DType x_dtype;
DType w_dtype; DType w_dtype;
DType wkspace_dtype;
DType barrier_dtype;
DType dgamma_part_dtype;
DType dbeta_part_dtype;
bool zero_centered_gamma; bool zero_centered_gamma;
float eps; float eps;
int sm_margin; int sm_margin;
}; };
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype, pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype,
DType wkspace_dtype, DType barrier_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype,
bool zero_centered_gamma, float eps, int sm_margin); bool zero_centered_gamma, float eps, int sm_margin);
struct SoftmaxDescriptor { struct SoftmaxDescriptor {
size_t batch; size_t batch_size;
size_t pad_batch; size_t padding_size;
size_t heads; size_t head_dim;
size_t q_seqlen; size_t q_seqlen;
size_t k_seqlen; size_t k_seqlen;
DType dtype; DType dtype;
float scale_factor; float scale_factor;
}; };
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads, pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t q_seqlen, size_t k_seqlen, DType dtype, size_t head_dim, size_t q_seqlen, size_t k_seqlen,
float scale_factor); DType dtype, float scale_factor);
struct CustomCallFusedAttnDescriptor { struct CustomCallFusedAttnDescriptor {
size_t batch; size_t batch_size;
size_t num_head;
size_t num_gqa_groups;
size_t q_max_seqlen; size_t q_max_seqlen;
size_t kv_max_seqlen; size_t kv_max_seqlen;
size_t num_heads;
size_t num_gqa_groups;
size_t head_dim; size_t head_dim;
size_t wkspace_size;
float scaling_factor; float scaling_factor;
float dropout_probability; float dropout_probability;
NVTE_Bias_Type bias_type; NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type; NVTE_Mask_Type mask_type;
DType dtype; DType dtype;
DType wkspace_dtype;
bool is_training; bool is_training;
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
NVTE_Mask_Type mask_type, DType dtype, bool is_training); float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -135,13 +136,21 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t ...@@ -135,13 +136,21 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetLayerNormForwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps
);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
bool zero_centered_gamma, float eps
);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -172,15 +181,41 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, ...@@ -172,15 +181,41 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len); std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
......
...@@ -28,66 +28,6 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q ...@@ -28,66 +28,6 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream); cudaStream_t stream);
class WorkspaceManager {
public:
static WorkspaceManager &Instance() {
static thread_local WorkspaceManager instance;
return instance;
}
WorkspaceManager() {}
~WorkspaceManager() { Clear_(); }
void *GetWorkspace(size_t size = 4194304) {
ReallocateIfNeed_(size);
return workspace_;
}
template <typename... Args>
inline auto GetWorkspace(Args... args) {
auto asks = std::array<size_t, sizeof...(Args)>{args...};
std::array<size_t, sizeof...(Args) + 1> offsets = {0};
std::array<void *, sizeof...(Args)> workspaces = {nullptr};
std::transform_inclusive_scan(
asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus<size_t>{},
[=](auto x) { return PadSize_(x); }, 0);
auto *workspace = GetWorkspace(offsets.back());
std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
[workspace](auto x) { return static_cast<char *>(workspace) + x; });
return workspaces;
}
private:
void *workspace_ = nullptr;
size_t size_ = 0;
size_t PadSize_(size_t size) {
constexpr size_t alignment = 128;
return ((size + alignment - 1) / alignment) * alignment;
}
void Clear_() {
if (workspace_ != nullptr) {
NVTE_CHECK_CUDA(cudaFree(workspace_));
}
workspace_ = nullptr;
size_ = 0;
}
void Allocate_(size_t new_size) {
new_size = PadSize_(new_size);
NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
size_ = new_size;
}
void ReallocateIfNeed_(size_t new_size) {
if (new_size > size_) {
Clear_();
Allocate_(new_size);
}
}
};
class cudaDevicePropertiesManager { class cudaDevicePropertiesManager {
public: public:
static cudaDevicePropertiesManager &Instance() { static cudaDevicePropertiesManager &Instance() {
......
...@@ -22,7 +22,7 @@ from ..dot import type_safe_dot_general ...@@ -22,7 +22,7 @@ from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernrom_geglu_fp8_mlp, geglu from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..softmax import is_softmax_kernel_available from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
...@@ -886,7 +886,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -886,7 +886,7 @@ class LayerNormMLP(TransformerEngineBase):
if use_fused_ln_mlp: if use_fused_ln_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernrom_geglu_fp8_mlp(y, out = layernorm_geglu_fp8_mlp(y,
scale, scale,
ln_bias, [kernel_1, kernel_2], ln_bias, [kernel_1, kernel_2],
fp8_meta_package, fp8_meta_package,
......
...@@ -55,7 +55,7 @@ def _geglu_bwd_rule(ctx, g): ...@@ -55,7 +55,7 @@ def _geglu_bwd_rule(ctx, g):
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule) _geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
def layernrom_geglu_fp8_mlp(x: jnp.ndarray, def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, beta: jnp.ndarray,
kernels: List[jnp.ndarray], kernels: List[jnp.ndarray],
...@@ -86,25 +86,25 @@ def layernrom_geglu_fp8_mlp(x: jnp.ndarray, ...@@ -86,25 +86,25 @@ def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'" "if layernorm_type is 'rmsnorm'"
output = _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale, output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon) zero_centered_gamma, epsilon)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13)) @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float): zero_centered_gamma: bool, epsilon: float):
output, _ = _layernrom_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype, scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon) layernorm_type, zero_centered_gamma, epsilon)
return output return output
def _layernrom_geglu_fp8_mlp_fwd_rule( def _layernorm_geglu_fp8_mlp_fwd_rule(
x, x,
gamma, gamma,
beta, beta,
...@@ -209,7 +209,7 @@ def _layernrom_geglu_fp8_mlp_fwd_rule( ...@@ -209,7 +209,7 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
return dot_2_output, ctx return dot_2_output, ctx
def _layernrom_geglu_fp8_mlp_bwd_rule( def _layernorm_geglu_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument fwd_dtype, # pylint: disable=unused-argument
bwd_dtype, bwd_dtype,
layernorm_type, layernorm_type,
...@@ -307,5 +307,5 @@ def _layernrom_geglu_fp8_mlp_bwd_rule( ...@@ -307,5 +307,5 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
fp8_max, amax, scale, scale_inv fp8_max, amax, scale, scale_inv
_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule, _layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
_layernrom_geglu_fp8_mlp_bwd_rule) _layernorm_geglu_fp8_mlp_bwd_rule)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment