Unverified Commit 4077ccc1 authored by Alp Dener and committed by GitHub

[JAX] Custom Op Workspace Tensors from XLA Buffers (#532)



* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors (a minimal sketch of this pattern follows the commit metadata below).
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unused GEMM C++ API in TE-JAX
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed import order for linting
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed custom op errors due to incorrect static arg nums in JAX jit
Signed-off-by: Alp Dener <adener@nvidia.com>

* shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for blank lines
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent bd7fd0a6
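The core pattern of this change, shown as a minimal, self-contained sketch (not the actual TE-JAX code): the inner primitive's abstract rule appends a workspace array to its outputs so that XLA allocates the buffer, and a separate outer abstract rule strips it off again so callers never see it. The helper fake_workspace_sizes below is hypothetical and stands in for the transformer_engine_jax.get_*_workspace_sizes bindings added in this commit; its shape and dtype are made up for illustration.

import jax.numpy as jnp
from jax.core import ShapedArray


def fake_workspace_sizes(batch_size, hidden_size):
    # Hypothetical stand-in for the C++ query (e.g. get_layernorm_fwd_workspace_sizes),
    # which runs the kernel with null data pointers and reports the workspace it needs.
    return (batch_size * 4,), jnp.uint8    # illustrative shape/dtype only


def inner_abstract(x_aval):
    # Inner primitive outputs: the real result plus a workspace tensor for XLA to allocate.
    out_aval = ShapedArray(x_aval.shape, x_aval.dtype)
    wk_shape, wk_dtype = fake_workspace_sizes(x_aval.shape[0], x_aval.shape[-1])
    wkspace_aval = ShapedArray(wk_shape, wk_dtype)
    return out_aval, wkspace_aval


def outer_abstract(x_aval):
    # Outer primitive hides the workspace from callers (mirrors outer_abstract in this commit).
    out_aval, _ = inner_abstract(x_aval)
    return out_aval


if __name__ == "__main__":
    x = ShapedArray((8, 128), jnp.float32)
    print(inner_abstract(x))    # (output, workspace) as seen by XLA
    print(outer_abstract(x))    # only the user-visible output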
......@@ -25,3 +25,5 @@ tests/cpp/build/
docs/_build
.ipynb_checkpoints
docs/doxygen
*.log
CMakeFiles/CMakeSystem.cmake
\ No newline at end of file
......@@ -20,7 +20,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernrom_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
GEMM_CASES = [
(256, 256, 512),
......@@ -196,7 +196,7 @@ class TestFP8Dot:
# out = (x * y) * z
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(layernrom_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
......
......@@ -59,8 +59,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
......@@ -248,6 +246,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
return;
}
// The cuDNN stream is set here, after the workspace-sizing early return, so that dummy
// kernel calls with null streams can be used to query the cuDNN workspace size.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ},
......@@ -300,8 +302,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
......@@ -519,6 +519,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
return;
}
// The cuDNN stream is set here, after the workspace-sizing early return, so that dummy
// kernel calls with null streams can be used to query the cuDNN workspace size.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{q, devPtrQ},
......
......@@ -642,8 +642,6 @@ void fused_attn_max_512_fwd_impl(
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size,
cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
......@@ -754,6 +752,10 @@ void fused_attn_max_512_fwd_impl(
return;
}
// The cuDNN stream is set here, after the workspace-sizing early return, so that dummy
// kernel calls with null streams can be used to query the cuDNN workspace size.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// Prepare actual seqlen
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
......@@ -845,9 +847,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) {
try {
// Create cudnn handle
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{
b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability,
layout, bias_type, mask_type, tensorType, false};
......@@ -1194,6 +1193,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
return;
}
// The cuDNN stream is set here, after the workspace-sizing early return, so that dummy
// kernel calls with null streams can be used to query the cuDNN workspace size.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
......
......@@ -1007,8 +1007,6 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cudaStream_t stream,
cudnnHandle_t handle_) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
FADescriptor descriptor{
b, h, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability, layout,
......@@ -1212,6 +1210,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
return;
}
// The cuDNN stream is set here, after the workspace-sizing early return, so that dummy
// kernel calls with null streams can be used to query the cuDNN workspace size.
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
......@@ -1324,8 +1326,6 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cudaStream_t stream,
cudnnHandle_t handle_) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
FADescriptor descriptor{
b, h, s_q, s_kv, d,
attnScale, false, dropoutProbability, layout,
......@@ -1745,6 +1745,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
return;
}
// The cuDNN stream is set here, after the workspace-sizing early return, so that dummy
// kernel calls with null streams can be used to query the cuDNN workspace size.
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
......
......@@ -159,14 +159,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32;
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
......@@ -227,6 +219,16 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
return;
}
// Tensor checks are deferred to this point so that workspace sizes can be queried with null data pointers
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
......@@ -273,15 +275,6 @@ void layernorm_bwd(const Tensor& dz,
auto otype = wtype;
auto ctype = DType::kFloat32;
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype);
......@@ -354,6 +347,16 @@ void layernorm_bwd(const Tensor& dz,
return;
}
// Tensor checks are deferred to this point so that workspace sizes can be queried with null data pointers
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
......
......@@ -113,12 +113,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
const bool fp8_out = is_fp8_dtype(otype);
auto ctype = DType::kFloat32;
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
......@@ -172,6 +166,15 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
return;
}
// Tensor checks are deferred to this point so that workspace sizes can be queried with null data pointers
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
......@@ -204,13 +207,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
auto otype = wtype;
auto ctype = DType::kFloat32;
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(rsigma.data.dtype == ctype);
......@@ -268,6 +264,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
return;
}
// Tensor checks are deferred to this point so that workspace sizes can be queried with null data pointers
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
......
......@@ -10,13 +10,6 @@ import operator
import os
import warnings
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
......@@ -28,6 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
from jax._src import dispatch
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
......@@ -58,6 +58,7 @@ def te_dtype_to_jax_dtype(te_dtype):
TEDType.kInt64: jnp.int64,
TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
TEDType.kFloat8E5M2: jnp.float8_e5m2,
TEDType.kByte: jnp.uint8
}
if te_dtype not in converter:
......@@ -94,6 +95,7 @@ def jax_dtype_to_te_dtype(jax_dtype):
jnp.int64.dtype: TEDType.kInt64,
jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
jnp.uint8.dtype: TEDType.kByte,
}
if jax_dtype not in converter:
......@@ -124,7 +126,7 @@ def _check_valid_batch_dims(bdims):
class BasePrimitive(metaclass=ABCMeta):
"""
jax premitive
jax primitive
"""
@staticmethod
......@@ -135,6 +137,13 @@ class BasePrimitive(metaclass=ABCMeta):
"""
return NotImplemented
@classmethod
def outer_abstract(cls, *args, **kwargs):
"""
optional abstract wrapper to eliminate workspace tensors
"""
return cls.abstract(*args, **kwargs)
@staticmethod
@abstractmethod
def lowering():
......@@ -196,7 +205,7 @@ def register_primitive(cls):
dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl)
outer_p.def_abstract_eval(cls.abstract)
outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
......@@ -287,9 +296,9 @@ class LayerNormFwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, beta_aval, **kwargs): # pylint: disable=unused-argument
def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
"""
LayerNorm fwd abstract
LayerNorm fwd inner primitive abstract
"""
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -303,6 +312,28 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
True, kwargs['zero_centered_gamma'], kwargs['epsilon']
)
wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, _, _ = \
LayerNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, mu_aval, rsigma_aval
@staticmethod
......@@ -333,10 +364,14 @@ class LayerNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape]
......@@ -347,8 +382,16 @@ class LayerNormFwdPrimitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in FWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -364,7 +407,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
to describe implementation
"""
assert LayerNormFwdPrimitive.inner_primitive is not None
out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind(
out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
return out, mu, rsigma
......@@ -449,9 +492,9 @@ class LayerNormBwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
"""
Layernorm bwd abstract
Layernorm bwd inner primitive abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
......@@ -464,6 +507,34 @@ class LayerNormBwdPrimitive(BasePrimitive):
dx_aval = core.raise_to_shaped(dz_aval)
dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = \
transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
True, kwargs['zero_centered_gamma'], kwargs['epsilon']
)
wkspace_aval = dx_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = dx_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0],
dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]))
return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \
dgamma_part_aval, dbeta_part_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = \
LayerNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval, dbeta_aval
@staticmethod
......@@ -488,22 +559,32 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(b_shape, b_type.element_type),
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
operands = [dz, mu, rsigma, x, gamma]
operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:]
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
dbeta_part_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
jax_dtype_to_te_dtype(dbeta_part_aval.dtype),
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -516,7 +597,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
@staticmethod
def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
assert LayerNormBwdPrimitive.inner_primitive is not None
dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind(
dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
return dx, dgamma, dbeta
......@@ -609,9 +690,9 @@ class RmsNormFwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument
def abstract(x_aval, gamma_aval, **kwargs):
"""
RMSNorm fwd abstract
RMSNorm fwd inner primitive abstract
"""
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -624,6 +705,27 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False, False, kwargs['epsilon']
)
wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, rsigma_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd outer primitive abstract
"""
out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval
@staticmethod
......@@ -643,9 +745,13 @@ class RmsNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
]
operands = [x, gamma]
operand_shapes = [x_shape, g_shape]
......@@ -656,8 +762,16 @@ class RmsNormFwdPrimitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in FWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -673,7 +787,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
to describe implementation
"""
assert RmsNormFwdPrimitive.inner_primitive is not None
out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
return out, rsigma
@staticmethod
......@@ -744,15 +858,9 @@ class RmsNormBwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(
dz_aval,
x_aval,
rsigma_aval,
gamma_aval,
**kwargs # pylint: disable=unused-argument
):
def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
"""
RMSNorm bwd abstract
RMSNorm bwd inner primitive abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
......@@ -764,6 +872,30 @@ class RmsNormBwdPrimitive(BasePrimitive):
dx_aval = core.raise_to_shaped(dz_aval)
dgamma_aval = core.raise_to_shaped(gamma_aval)
wkspace_info, barrier_info, dgamma_part_info, _ = \
transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
False, False, kwargs['epsilon']
)
wkspace_aval = dx_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = dx_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval
@staticmethod
......@@ -782,9 +914,15 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:]
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(dgamma_part_aval.shape,
jax_dtype_to_ir_dtype(dgamma_part_aval.dtype))
]
operands = [dz, rsigma, x, gamma]
operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
......@@ -795,8 +933,16 @@ class RmsNormBwdPrimitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
0, # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -809,7 +955,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
@staticmethod
def impl(dz, x, rsigma, gamma, epsilon):
assert RmsNormBwdPrimitive.inner_primitive is not None
dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
dx, dgamma, _, _, _ = \
RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
return dx, dgamma
@staticmethod
......@@ -1721,40 +1868,60 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Self fused attention fwd abstract
Self fused attention fwd inner primitive abstract
"""
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, scaling_factor, is_training
# outer_primitive is squeezed_mask, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
*batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape
assert nqkv == 3
assert qkv_aval.dtype == bias_aval.dtype
output_shape = (*batch_shape, max_seqlen, num_head, head_dim)
output_dtype = qkv_dtype
output_shape = (*batch_shape, max_seqlen, num_heads, head_dim)
out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, num_head, num_head,
attn_mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1)
softmax_shape = (*batch_shape, num_heads, max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_dtype = seed_dtype
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, batch_shape)
wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
out_aval = qkv_aval.update(shape=output_shape, dtype=output_dtype)
softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype)
rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Self fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
......@@ -1763,23 +1930,25 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
"""
Self fused attention fwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
*batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
batch = reduce(operator.mul, batch_shape)
operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval = ctx.avals_in[0]
*batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -1792,7 +1961,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
cu_seqlen = generate_cu_seqlen(seqlen)
output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
qkv,
bias,
cu_seqlen,
......@@ -1897,16 +2066,35 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
"""
Self fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
del scaling_factor, dropout_probability, is_training
del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval
assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
*batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape
assert nqkv == 3
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
batch_size = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = qkv_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dqkv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Self fused attention bwd outer primitive abstract
"""
dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dqkv_aval, dbias_aval
@staticmethod
......@@ -1915,24 +2103,25 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
"""
Self fused attention bwd lowering rules
"""
qkv_aval, _, _, _, _, _, _ = ctx.avals_in
*batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
batch = reduce(operator.mul, batch_shape)
operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval = ctx.avals_in[0]
*batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -1945,7 +2134,7 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
cu_seqlen = generate_cu_seqlen(seqlen)
dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
qkv,
bias,
softmax_aux,
......@@ -2067,50 +2256,62 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
"""
Cross fused attention fwd abstract
"""
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del scaling_factor, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
*q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
*kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == kv_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
output_shape = q_aval.shape
output_dtype = q_dtype
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, q_num_head,
kv_num_head, q_max_seqlen, kv_max_seqlen,
attn_bias_type, attn_mask_type, dropout_probability,
num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
softmax_aux_dtype = q_dtype
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1)
softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32)
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_dtype = seed_dtype
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
out_aval = q_aval.update(shape=output_shape, dtype=output_dtype)
softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
......@@ -2119,25 +2320,27 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
"""
Cross fused attention fwd lowering rules
"""
q_aval, kv_aval, *_ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
batch_size, q_max_seqlen, kv_max_seqlen,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -2151,7 +2354,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q,
kv,
bias,
......@@ -2266,7 +2469,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
Cross fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
......@@ -2274,9 +2477,35 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dkv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention bwd outer primitive abstract
"""
dq_aval, dkv_aval, dbias_aval, _ = \
CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dkv_aval, dbias_aval
@staticmethod
......@@ -2286,13 +2515,6 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
"""
Cross fused attention bwd lowering rules
"""
q_aval, kv_aval, *_ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
......@@ -2302,12 +2524,19 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
q_aval, kv_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
batch_size, q_max_seqlen, kv_max_seqlen,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -2321,7 +2550,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q,
kv,
bias,
......@@ -3143,9 +3372,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
zero_centered_gamma, epsilon):
"""
LayerNorm fwd (fp8 out) abstract
LayerNorm fwd (fp8 out) inner primitive abstract
"""
del zero_centered_gamma, epsilon
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -3157,10 +3385,32 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert gamma_aval.size == beta_aval.size
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type
jax_dtype_to_te_dtype(out_dtype),
True, zero_centered_gamma, epsilon
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_aval = x_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = x_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = \
LayerNormFwdFp8Primitive.abstract(*args, **kwargs)
return out_aval, mu_aval, rsigma_aval, updated_amax_aval
@staticmethod
......@@ -3204,11 +3454,15 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
......@@ -3221,8 +3475,16 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in FWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -3242,7 +3504,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
to describe implementation
"""
assert LayerNormFwdFp8Primitive.inner_primitive is not None
out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind(
out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
x,
gamma,
beta,
......@@ -3359,9 +3621,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
@staticmethod
def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
"""
RMSNorm fwd (fp8 out) abstract
RMSNorm fwd (fp8 out) inner primitive abstract
"""
del epsilon
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -3374,10 +3635,31 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
rsigama_dtype = jnp.float32
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch_size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(out_dtype), # out te_dtype
False, False, epsilon
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_aval = x_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = x_aval.update(shape=barrier_info[0],
dtype=te_dtype_to_jax_dtype(barrier_info[1]))
return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval, amax_aval
@staticmethod
......@@ -3414,10 +3696,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval, barrier_aval = ctx.avals_out[-2:]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......@@ -3428,8 +3714,16 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in FWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -3449,7 +3743,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
to describe implementation
"""
assert RmsNormFwdFp8Primitive.inner_primitive is not None
out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
gamma,
amax,
scale,
......
......@@ -29,7 +29,6 @@ pybind11::dict Registrations() {
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
dict["te_gemm"] = EncapsulateFunction(Gemm);
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
......@@ -56,14 +55,19 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("get_fused_attn_backend", &GetFusedAttnBackend);
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......
......@@ -22,7 +22,6 @@
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/gemm.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h"
......@@ -33,11 +32,12 @@
namespace transformer_engine {
namespace jax {
constexpr size_t kCublasLtForwardWorkspaceSize = 32 * 1024 * 1024;
constexpr size_t kCublasLtBackwardWorkspaceSize = 32 * 1024 * 1024;
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
......@@ -61,33 +61,37 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape,
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
DType B_dtype, DType D_dtype, bool transa, bool transb,
bool use_split_accumulator) {
return PackOpaque(CustomCallGemmDescriptor{m, n, k, A_dtype, B_dtype, D_dtype, transa, transb,
use_split_accumulator});
}
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype,
DType wkspace_dtype, DType barrier_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype,
bool zero_centered_gamma, float eps, int sm_margin) {
return PackOpaque(
CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps, sm_margin});
return PackOpaque(CustomCallNormDescriptor{batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes,
x_dtype, w_dtype, wkspace_dtype, barrier_dtype,
dgamma_part_dtype, dbeta_part_dtype,
zero_centered_gamma, eps, sm_margin});
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
size_t q_seqlen, size_t k_seqlen, DType dtype,
float scale_factor) {
return PackOpaque(
SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor});
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor) {
return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
dtype, scale_factor});
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor,
dropout_probability, bias_type, mask_type, dtype, is_training});
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype,
is_training});
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
......@@ -247,48 +251,56 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
output_trans_tensor.data(), stream);
}
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *A = buffers[0];
auto *B = buffers[1];
auto *A_scale_inverse = reinterpret_cast<float *>(buffers[2]);
auto *B_scale_inverse = reinterpret_cast<float *>(buffers[3]);
auto *D = buffers[4];
// We transposes shape of A, B and D here to correctly invoke
// cuBlasLt GEMM (col-major) for row-major data.
const auto &desc = *UnpackOpaque<CustomCallGemmDescriptor>(opaque, opaque_len);
auto m = desc.m;
auto n = desc.n;
auto k = desc.k;
auto A_shape = std::vector<size_t>{k, m};
auto A_tensor = TensorWrapper(A, A_shape, desc.A_dtype, nullptr, nullptr, A_scale_inverse);
auto B_shape = std::vector<size_t>{n, k};
auto B_tensor = TensorWrapper(B, B_shape, desc.B_dtype, nullptr, nullptr, B_scale_inverse);
auto D_shape = std::vector<size_t>{n, m};
auto D_tensor = TensorWrapper(D, D_shape, desc.D_dtype);
auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps
) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
// empty tensor wrappers are okay just to get workspace size
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
size_t workspace_size = kCublasLtForwardWorkspaceSize;
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
}
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(desc.transb) ? CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false,
desc.use_split_accumulator, 0, stream);
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
}
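The (shape, dtype) pairs returned above describe element counts rather than bytes. A minimal sketch of reducing one pair to a byte count before the corresponding XLA buffer is requested, assuming typeToSize() (used elsewhere in this file) returns the per-element size in bytes; WorkspaceBytes is a hypothetical helper, not part of this patch:

// Hypothetical helper: total byte size of a workspace described by a shape
// vector and a TE DType.
size_t WorkspaceBytes(const std::vector<size_t> &shape, DType dtype) {
  size_t elems = 1;
  for (auto dim : shape) elems *= dim;  // total element count
  return elems * typeToSize(dtype);     // typeToSize() gives bytes per element
}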
void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
void *bias, void *output, DType out_dtype, void *mu, void *rsigma,
float *amax, float *scale, float *scale_inv, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n};
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size,
size_t workspace_size, size_t barrier_size,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *bias, void *output, DType out_dtype,
void *workspace, DType work_dtype, void *barrier, DType barrier_dtype,
void *mu, void *rsigma, float *amax, float *scale, float *scale_inv,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto workspace_shape = std::vector<size_t>{workspace_size};
auto barrier_shape = std::vector<size_t>{barrier_size};
auto is_layer_norm = (bias) ? true : false;
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
......@@ -300,63 +312,95 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
// Create uninitialized workspace and barrier tensors and initialize them on the first call
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the required workspace
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
num_sm, workspace_tensor.data(), barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, dummy_workspace_tensor.data(),
dummy_barrier_tensor.data());
rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
}
}
size_t workspace_size =
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) +
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
void *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
bool zero_centered_gamma, float eps
) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32;
auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
// empty tensor wrappers are okay just to get workspace size
auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto barrier_tensor =
TensorWrapper(reinterpret_cast<char *>(workspace) + dummy_workspace_tensor.shape().data[0],
dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
// initialize dbeta_part shape info here -- LayerNorm will set it below but RMSNorm will not
std::vector<size_t> dbeta_part_shape;
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
} else {
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0};
}
auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
}
void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
void *ograd, void *mu, void *rsigma, void *xgrad, void *wgrad,
void *dbeta, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n};
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
bool zero_centered_gamma, float eps,
void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd,
void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype,
void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta,
void *dgamma_part, DType dgamma_dtype,
void *dbeta_part, DType dbeta_dtype,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32;
auto is_layer_norm = (dbeta) ? true : false;
......@@ -374,62 +418,21 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
size_t dbeta_part_size{};
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the workspace
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_size = dummy_dbeta_part_tensor.shape().data[0] *
dummy_dbeta_part_tensor.shape().data[1] *
typeToSize(dummy_dbeta_part_tensor.dtype());
} else {
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), stream, num_sm,
dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
}
size_t workspace_size =
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype());
size_t barrier_size =
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] *
dummy_dgamma_part_tensor.shape().data[1] *
typeToSize(dummy_dgamma_part_tensor.dtype());
auto [workspace, dgamma_part, dbeta_part, barrier] = WorkspaceManager::Instance().GetWorkspace(
workspace_size, dgamma_part_size, dbeta_part_size, barrier_size);
auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
auto barrier_tensor =
TensorWrapper(barrier, dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dummy_dgamma_part_tensor.shape(),
dummy_dgamma_part_tensor.dtype());
auto workspace_shape = std::vector<size_t>{wkspace_size};
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto barrier_shape = std::vector<size_t>{barrier_size};
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(),
dummy_dbeta_part_tensor.dtype());
auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
......@@ -437,6 +440,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
} else {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
......@@ -456,22 +460,29 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *mu = buffers[7];
auto *rsigma = buffers[8];
auto *amax_out = buffers[9];
auto *workspace = buffers[10];
auto *barrier = buffers[11];
assert(amax_out == amax);
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -481,33 +492,48 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto *output = buffers[3];
auto *mu = buffers[4];
auto *rsigma = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto *dgamma_part_sizes = desc.dgamma_part_sizes;
auto *dbeta_part_sizes = desc.dbeta_part_sizes;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
......@@ -520,9 +546,16 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *xgrad = buffers[5];
auto *wgrad = buffers[6];
auto *dbeta = buffers[7];
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
auto *workspace = buffers[8];
auto *barrier = buffers[9];
auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
}
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -534,24 +567,31 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *output = buffers[5];
auto *rsigma = buffers[6];
auto *amax_out = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
assert(amax_out == amax);
void *bias = nullptr;
void *mu = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -559,6 +599,8 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto *weight = buffers[1];
auto *output = buffers[2];
auto *rsigma = buffers[3];
auto *workspace = buffers[4];
auto *barrier = buffers[5];
void *bias = nullptr;
void *mu = nullptr;
......@@ -567,18 +609,23 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -588,21 +635,35 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto *weight = buffers[3];
auto *xgrad = buffers[4];
auto *wgrad = buffers[5];
auto *workspace = buffers[6];
auto *barrier = buffers[7];
auto *dgamma_part = buffers[8];
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto batch_size = desc.batch_size;
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_sizes = desc.dgamma_part_sizes;
size_t dbeta_part_sizes[2] = {0, 0};
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
auto barrier_dtype = desc.barrier_dtype;
auto dgamma_part_dtype = desc.dgamma_part_dtype;
auto dbeta_part_dtype = DType::kByte;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
void *mu = nullptr;
void *dbeta = nullptr;
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
}
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -645,7 +706,7 @@ void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaqu
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen};
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
......@@ -662,7 +723,7 @@ void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaq
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen};
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
......@@ -680,8 +741,9 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.pad_batch, 1, desc.q_seqlen, desc.k_seqlen};
auto io_shape = std::vector<size_t>{desc.batch_size, desc.head_dim,
desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, io_shape, dtype);
......@@ -705,7 +767,7 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch * desc.heads;
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
......@@ -724,7 +786,7 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch * desc.heads;
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
......@@ -750,91 +812,225 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
return backend;
}
/*
NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
attention forward kernels in:
- common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
- common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(
NVTETensorPack *tensor_pack, const CustomCallFusedAttnDescriptor *desc,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr
) {
auto batch_size = desc->batch_size;
auto num_heads = desc->num_heads;
auto q_max_seqlen = desc->q_max_seqlen;
auto kv_max_seqlen = desc->kv_max_seqlen;
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1;
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape = std::vector<size_t>{
batch_size, num_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = desc->dtype;
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
tensor_pack->size = 2;
Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
rng_state_aux->data.dptr = rng_state_buf;
rng_state_aux->data.shape = std::vector<size_t>{2};
rng_state_aux->data.dtype = DType::kInt64;
// correct softmax shape/dtype
softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux->data.dtype = DType::kFloat32;
// include bias if enabled
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
tensor_pack->size = 3;
Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
bias_aux->data.dptr = bias_buf;
bias_aux->data.shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
bias_aux->data.dtype = desc->dtype;
}
}
}
/*
NOTE: The backward fused attention kernels accept auxiliary tensors as explicit function arguments
instead of an NVTETensorPack, and the nvte_fused_attn_bwd() API handles pulling the necessary
tensors out of the tensor pack for the active kernel. That means we can simply dump everything we
have into the tensor pack and not worry about its sizing for the backward pass.
TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(
NVTETensorPack* tensor_pack, const CustomCallFusedAttnDescriptor *desc,
NVTE_Fused_Attn_Backend backend, void* softmax_buf, void* rng_state_buf, void* bias_buf
) {
// Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path
auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend,
softmax_buf, rng_state_buf, bias_buf);
// correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor* softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = desc->dtype;
}
}
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto o_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
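// Passing an empty workspace tensor and a null stream is expected to make this a
// sizing-only call that fills query_workspace_tensor's shape/dtype without
// launching the attention kernel.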
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
// input buffers from XLA
void *qkv = buffers[0];
void *bias = buffers[1];
void *cu_seqlens = buffers[2];
void *seed = buffers[3];
// output
// output buffers from XLA
void *output = buffers[4];
void *softmax_aux = buffers[5];
void *rng_state = buffers[6];
void *workspace = buffers[7];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
// tensor sizes
auto batch_size = descriptor.batch_size;
auto max_seqlen = descriptor.q_max_seqlen;
auto num_heads = descriptor.num_heads;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in self-attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in qkvpacked attention.");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto cu_seqlens_tensor = TensorWrapper(
cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// output tensors
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(
output, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
// aux tensors
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
TensorWrapper query_workspace_tensor;
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
rng_state_tensor.data(), max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), stream);
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto output_shape = std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
nvte_tensor_pack_destroy(&aux_output_tensors);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1},
DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -842,7 +1038,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
// input buffers from XLA
void *qkv = buffers[0];
void *bias = buffers[1];
void *softmax_aux = buffers[2];
......@@ -851,82 +1047,107 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
void *doutput = buffers[5];
void *cu_seqlens = buffers[6];
// output
// output buffers from XLA
void *dqkv = buffers[7];
void *dbias = buffers[8];
void *workspace = buffers[9];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
// tensor sizes
auto batch_size = descriptor.batch_size;
auto max_seqlen = descriptor.q_max_seqlen;
auto num_heads = descriptor.num_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in self-attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in qkvpacked attention.");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto output_shape = std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
TensorWrapper(cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// TODO: need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 3;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
softmax_aux, rng_state, bias);
TensorWrapper query_workspace_tensor;
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream);
workspace_tensor.data(), stream);
size_t workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_tensor_pack_destroy(&aux_input_tensors);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
nvte_tensor_pack_destroy(&aux_output_tensors);
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -934,7 +1155,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
// input buffers from XLA
void *q = buffers[0];
void *kv = buffers[1];
void *bias = buffers[2];
......@@ -942,83 +1163,115 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
void *kv_cu_seqlens = buffers[4];
void *seed = buffers[5];
// output
// output buffers from XLA
void *output = buffers[6];
void *softmax_aux = buffers[7];
void *rng_state = buffers[8];
void *workspace = buffers[9];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
// tensor sizes
auto batch_size = descriptor.batch_size;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// output tensors
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
// aux tensors
// F16 doesn't use s_tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
TensorWrapper query_workspace_tensor;
// cuDNN workspace
auto workspace_tensor = TensorWrapper(
workspace, std::vector<size_t>{descriptor.wkspace_size}, descriptor.wkspace_dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
nvte_tensor_pack_destroy(&aux_output_tensors);
}
auto workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
nvte_tensor_pack_destroy(&aux_output_tensors);
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -1026,7 +1279,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
// input buffers from XLA
void *q = buffers[0];
void *kv = buffers[1];
void *bias = buffers[2];
......@@ -1037,85 +1290,72 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
void *q_cu_seqlens = buffers[7];
void *kv_cu_seqlens = buffers[8];
// output
// output buffers from XLA
void *dq = buffers[9];
void *dkv = buffers[10];
void *dbias = buffers[11];
void *workspace = buffers[12];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
// tensor sizes
auto batch_size = descriptor.batch_size;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
softmax_aux, rng_state, bias);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// TODO(rewang): need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 3;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
mask_type, query_workspace_tensor.data(), stream);
size_t workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
mask_type, workspace_tensor.data(), stream);
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
} // namespace jax
......
......@@ -52,68 +52,69 @@ struct CustomCallCommonDescriptor {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype);
struct CustomCallGemmDescriptor {
size_t m;
size_t n;
size_t k;
DType A_dtype;
DType B_dtype;
DType D_dtype;
bool transa;
bool transb;
bool use_split_accumulator;
};
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
DType B_dtype, DType D_dtype, bool transa, bool transb,
bool use_split_accumulator);
struct CustomCallNormDescriptor {
size_t n;
size_t hidden;
size_t batch_size;
size_t hidden_size;
size_t wkspace_size;
size_t barrier_size;
size_t *dgamma_part_sizes; // 2D tensor
size_t *dbeta_part_sizes; // 2D tensor
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
DType barrier_dtype;
DType dgamma_part_dtype;
DType dbeta_part_dtype;
bool zero_centered_gamma;
float eps;
int sm_margin;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype,
DType wkspace_dtype, DType barrier_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype,
bool zero_centered_gamma, float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch;
size_t pad_batch;
size_t heads;
size_t batch_size;
size_t padding_size;
size_t head_dim;
size_t q_seqlen;
size_t k_seqlen;
DType dtype;
float scale_factor;
};
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
size_t q_seqlen, size_t k_seqlen, DType dtype,
float scale_factor);
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor);
struct CustomCallFusedAttnDescriptor {
size_t batch;
size_t num_head;
size_t num_gqa_groups;
size_t batch_size;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t num_heads;
size_t num_gqa_groups;
size_t head_dim;
size_t wkspace_size;
float scaling_factor;
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
DType dtype;
DType wkspace_dtype;
bool is_training;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
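The Pack*Descriptor helpers above amount to serializing a POD struct into an opaque byte string that XLA hands back, untouched, to the custom call at launch time. Below is a minimal standalone sketch of that round trip; the struct and function names are simplified stand-ins, not the real CustomCallFusedAttnDescriptor API.
// Illustrative sketch only: pack a POD descriptor into raw bytes and recover it
// inside the custom call. pybind11::bytes in the real code; std::string here.
#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>
struct ToyFusedAttnDescriptor {      // hypothetical stand-in for the real descriptor
  std::size_t batch_size;
  std::size_t q_max_seqlen;
  std::size_t kv_max_seqlen;
  std::size_t wkspace_size;
  float scaling_factor;
};
std::string PackToyDescriptor(const ToyFusedAttnDescriptor &desc) {
  return std::string(reinterpret_cast<const char *>(&desc), sizeof(desc));
}
ToyFusedAttnDescriptor UnpackToyDescriptor(const char *opaque, std::size_t opaque_len) {
  ToyFusedAttnDescriptor desc{};
  if (opaque_len == sizeof(desc)) {
    std::memcpy(&desc, opaque, sizeof(desc));  // recover the struct from the opaque bytes
  }
  return desc;
}
int main() {
  ToyFusedAttnDescriptor d{8, 512, 512, 1 << 20, 0.125f};
  const std::string opaque = PackToyDescriptor(d);
  const ToyFusedAttnDescriptor back = UnpackToyDescriptor(opaque.data(), opaque.size());
  std::cout << back.batch_size << " " << back.wkspace_size << "\n";  // prints "8 1048576"
  return 0;
}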
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -135,13 +136,21 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps
);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
bool zero_centered_gamma, float eps
);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -172,15 +181,41 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
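The Get*FusedAttn*WorkspaceSizes entry points above let the Python side learn the cuDNN workspace requirement ahead of time and expose it as one more XLA output buffer, instead of cudaMalloc'ing inside the op. The following library-free sketch shows the underlying two-phase pattern (query the size with null buffers, then run with a caller-provided workspace); the kernel and variable names are illustrative only.
// Illustrative sketch only: a toy kernel that reports its workspace size when
// called without a workspace pointer, then runs once the caller provides one.
#include <cstddef>
#include <iostream>
#include <vector>
void toy_attn_kernel(const float *q, float *out, std::size_t n,
                     void *workspace, std::size_t *workspace_bytes) {
  if (workspace == nullptr) {
    *workspace_bytes = n * sizeof(float);   // sizing pass, nothing executes
    return;
  }
  float *scratch = static_cast<float *>(workspace);
  for (std::size_t i = 0; i < n; ++i) scratch[i] = 0.5f * q[i];  // pretend intermediate
  for (std::size_t i = 0; i < n; ++i) out[i] = scratch[i];
}
int main() {
  const std::size_t n = 8;
  std::vector<float> q(n, 2.0f), out(n, 0.0f);
  std::size_t wkspace_bytes = 0;
  toy_attn_kernel(q.data(), nullptr, n, nullptr, &wkspace_bytes);             // phase 1: query size
  std::vector<char> wkspace(wkspace_bytes);                                   // caller (XLA) allocates
  toy_attn_kernel(q.data(), out.data(), n, wkspace.data(), &wkspace_bytes);   // phase 2: execute
  std::cout << "workspace bytes: " << wkspace_bytes << ", out[0] = " << out[0] << "\n";
  return 0;
}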
......
......@@ -28,66 +28,6 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
class WorkspaceManager {
public:
static WorkspaceManager &Instance() {
static thread_local WorkspaceManager instance;
return instance;
}
WorkspaceManager() {}
~WorkspaceManager() { Clear_(); }
void *GetWorkspace(size_t size = 4194304) {
ReallocateIfNeed_(size);
return workspace_;
}
template <typename... Args>
inline auto GetWorkspace(Args... args) {
auto asks = std::array<size_t, sizeof...(Args)>{args...};
std::array<size_t, sizeof...(Args) + 1> offsets = {0};
std::array<void *, sizeof...(Args)> workspaces = {nullptr};
std::transform_inclusive_scan(
asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus<size_t>{},
[=](auto x) { return PadSize_(x); }, 0);
auto *workspace = GetWorkspace(offsets.back());
std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
[workspace](auto x) { return static_cast<char *>(workspace) + x; });
return workspaces;
}
private:
void *workspace_ = nullptr;
size_t size_ = 0;
size_t PadSize_(size_t size) {
constexpr size_t alignment = 128;
return ((size + alignment - 1) / alignment) * alignment;
}
void Clear_() {
if (workspace_ != nullptr) {
NVTE_CHECK_CUDA(cudaFree(workspace_));
}
workspace_ = nullptr;
size_ = 0;
}
void Allocate_(size_t new_size) {
new_size = PadSize_(new_size);
NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
size_ = new_size;
}
void ReallocateIfNeed_(size_t new_size) {
if (new_size > size_) {
Clear_();
Allocate_(new_size);
}
}
};
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
......
......@@ -22,7 +22,7 @@ from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernrom_geglu_fp8_mlp, geglu
from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
......@@ -886,7 +886,7 @@ class LayerNormMLP(TransformerEngineBase):
if use_fused_ln_mlp:
assert self.axis == -1    # Only support axis == -1 at this moment
out = layernrom_geglu_fp8_mlp(y,
out = layernorm_geglu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2],
fp8_meta_package,
......
......@@ -55,7 +55,7 @@ def _geglu_bwd_rule(ctx, g):
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
......@@ -86,25 +86,25 @@ def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon)
return output
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float):
output, _ = _layernrom_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon)
return output
def _layernrom_geglu_fp8_mlp_fwd_rule(
def _layernorm_geglu_fp8_mlp_fwd_rule(
x,
gamma,
beta,
......@@ -209,7 +209,7 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
return dot_2_output, ctx
def _layernrom_geglu_fp8_mlp_bwd_rule(
def _layernorm_geglu_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
......@@ -307,5 +307,5 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
fp8_max, amax, scale, scale_inv
_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule,
_layernrom_geglu_fp8_mlp_bwd_rule)
_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
_layernorm_geglu_fp8_mlp_bwd_rule)