"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "f25991370770c9f55ea8cf445e01301db61679d6"
Unverified Commit 18c2234c authored by Hua Huang, committed by GitHub

[JAX] XLA Custom Calls with FFI for FusedAttnFwd, Quantize, Transpose, ActLuFP8, LayerNormForwardFP8FFI, and LayerNormBackwardFFI (#1263)

* Add TransposeFFI, test passed
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add ActLuFP8FFI; fix TransposeFFI
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add QuantizeFFI
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FusedAttnForwardFFI and some unit tests
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor fix
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add LayerNormForwardFP8FFI & LayerNormBackwardFFI
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revise FusedAttnForwardFFI()
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add FFI_CudaGraph_Traits

All tests passed, ready for merge
Signed-off-by: Hua Huang <huah@nvidia.com>

* Bug fix for FFI data type mismatch

Also add a safeguard at the entrance to the FFI function
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 20c75295
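For reference: the lowering rules in this change pick between the legacy custom call and the new XLA FFI call via is_ffi_enabled(), and the added unit tests exercise both paths by flipping the NVTE_JAX_WITH_FFI environment variable before invoking the same cpp_extensions wrapper. A minimal, self-contained sketch of that toggle pattern, mirroring the new test_transpose test below (the positional argument order for tex.transpose is taken from those tests; that the flag is re-read on every call is an assumption based on how the tests toggle it):

import os
import jax
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex

x = jax.random.uniform(jax.random.PRNGKey(0), (256, 128), jnp.bfloat16)

os.environ["NVTE_JAX_WITH_FFI"] = "0"  # legacy custom-call lowering
noffi_out = tex.transpose(x, -1, 1)    # args: (input, static_axis_boundary, transpose_axis)

os.environ["NVTE_JAX_WITH_FFI"] = "1"  # new XLA FFI lowering
ffi_out = tex.transpose(x, -1, 1)

assert jnp.allclose(noffi_out, ffi_out)  # both lowerings should produce identical results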
@@ -3,9 +3,8 @@
# See LICENSE for license information.
from contextlib import nullcontext
-import functools
-import operator
from typing import Callable, List, Sequence, Union
+import os
import jax
import jax.numpy as jnp
@@ -14,12 +13,17 @@ import pytest
from jax import jit, value_and_grad
from flax import linen as nn
-from utils import assert_allclose
+from utils import assert_allclose, assert_tree_like_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import (
    _jax_transpose,
    _jax_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex
@@ -746,3 +750,102 @@ class TestNorm:
        assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
        if beta is not None:
            assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)


@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "input_shape, transpose_axis",
    [
        pytest.param((16, 16), 1, id="(16, 16)-1"),
        pytest.param((256, 128), 1, id="(256, 128)-1"),
        pytest.param((128, 512), 1, id="(128, 512)-1"),
        pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
        pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
        pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
    ],
)
class TestTranspose:
    def test_transpose(self, in_dtype, input_shape, transpose_axis):
        key = jax.random.PRNGKey(0)
        input_tensor = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        assert_allclose(jax_output, noffi_output)
        assert_allclose(noffi_output, ffi_output)

    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_cast_transpose(
            input, scale, amax, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((256, 128), id="(256, 128)"),
        pytest.param((128, 512, 8), id="(128, 512, 8)"),
    ],
)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "out_dtype",
    [
        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
    ],
)
def test_quantize(input_shape, in_dtype, out_dtype):
    amax = jnp.zeros(1, jnp.float32)
    scale = jnp.ones(1, jnp.float32)
    scale_inv = jnp.ones(1, jnp.float32)
    key = jax.random.PRNGKey(0)
    input = jax.random.uniform(key, input_shape, in_dtype)
    jax_output = _jax_cast_fp8(input, scale, amax, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "0"
    noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "1"
    ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
    assert_tree_like_allclose(jax_output, ffi_output)
    assert_tree_like_allclose(noffi_output, ffi_output)
@@ -67,7 +67,7 @@ BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_HIDDEN_DROPOUT: 0,
-_KEY_OF_ATTENTION_DROPOUT: 0,
+_KEY_OF_ATTENTION_DROPOUT: 0.0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm",
...
@@ -383,6 +383,12 @@ class ActLuFp8Primitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_act_lu_fp8_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv, act_enum=act_enum
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
...
@@ -15,6 +15,7 @@ from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
@@ -33,6 +34,7 @@ from .misc import (
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
is_ffi_enabled,
)
from ..sharding import (
global_mesh_resource,
@@ -352,14 +354,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused attention fwd lowering rules
"""
-operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed]
-operand_shapes = map(lambda x: x.type.shape, operands)
-out_types = [
-ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-for output in ctx.avals_out
-]
-args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
@@ -376,6 +370,60 @@ class FusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
wkspace_size=wkspace_aval.size,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
dtype=int(jax_dtype_to_te_dtype(q_aval.dtype)),
wkspace_dtype=int(jax_dtype_to_te_dtype(wkspace_aval.dtype)),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
...
@@ -9,6 +9,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
@@ -20,6 +21,7 @@ from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
is_ffi_enabled,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP
@@ -84,6 +86,12 @@ class CastFP8Primitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_quantize_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
...
@@ -102,6 +102,10 @@ class TransposePrimitive(BasePrimitive):
jnp.float8_e5m2,
]
if is_ffi_enabled():
name = "te_transpose_ffi"
out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
...
@@ -151,6 +151,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
@@ -172,6 +174,8 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
@@ -195,6 +199,8 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
@@ -202,6 +208,8 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -212,6 +220,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Softmax
@@ -253,6 +263,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
...
@@ -153,6 +153,51 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
act_enum, act_len);
}
Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type amax_out_buf, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto input_dims = input_buf.dimensions();
auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>());
auto n = input_dims.back();
auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
ActLuImpl(input, m, n, in_dtype, out_dtype, scale, stream, scale_inv, amax_out, output, act_type,
act_len);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *act_input = buffers[1];
...
@@ -30,18 +30,13 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
- common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
- common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
-void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
-const CustomCallFusedAttnDescriptor *desc,
-NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
-void *softmax_buf, void *rng_state_buf = nullptr,
-void *bias_buf = nullptr) {
+void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
+const size_t bias_batch, const size_t attn_heads,
+const size_t bias_heads, const size_t q_max_seqlen,
+const size_t kv_max_seqlen, DType dtype,
+NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
+void *softmax_buf, void *rng_state_buf = nullptr,
+void *bias_buf = nullptr) {
-auto input_batch = desc->input_batch;
-auto bias_batch = desc->bias_batch;
-auto attn_heads = desc->attn_heads;
-auto bias_heads = desc->bias_heads;
-auto q_max_seqlen = desc->q_max_seqlen;
-auto kv_max_seqlen = desc->kv_max_seqlen;
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1;
@@ -49,7 +44,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape =
std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
-softmax_aux->data.dtype = desc->dtype;
+softmax_aux->data.dtype = dtype;
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
@@ -69,7 +64,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
bias_aux->data.dptr = bias_buf;
bias_aux->data.shape =
std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
-bias_aux->data.dtype = desc->dtype;
+bias_aux->data.dtype = dtype;
}
}
}
@@ -82,22 +77,25 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
-void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
-const CustomCallFusedAttnDescriptor *desc,
-NVTE_Fused_Attn_Backend backend, void *softmax_buf,
-void *rng_state_buf, void *bias_buf) {
+void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
+const size_t bias_batch, const size_t attn_heads,
+const size_t bias_heads, const size_t q_max_seqlen,
+const size_t kv_max_seqlen, DType dtype,
+NVTE_Fused_Attn_Backend backend, void *softmax_buf,
+void *rng_state_buf, void *bias_buf) {
// Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path
auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
-PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, softmax_buf,
-rng_state_buf, bias_buf);
+PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
+q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
+dummy_backend, softmax_buf, rng_state_buf, bias_buf);
// correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
-softmax_aux->data.shape.at(3) = desc->kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
-softmax_aux->data.dtype = desc->dtype;
+softmax_aux->data.shape.at(3) = kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
+softmax_aux->data.dtype = dtype;
}
}
@@ -190,7 +188,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
@@ -279,8 +276,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
-PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
-softmax_aux);
+PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
+bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
+backend, softmax_aux);
/* cuDNN workspace */
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
@@ -335,6 +333,201 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_tensor_pack_destroy(&aux_output_tensors);
}
Error_Type FusedAttnForwardFFI(
cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, Buffer_Type seed_buf,
Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, int64_t input_batch_, int64_t bias_batch_, int64_t q_max_seqlen_,
int64_t kv_max_seqlen_, int64_t attn_heads_, int64_t num_gqa_groups_, int64_t bias_heads_,
int64_t head_dim_, int64_t max_segments_per_seq_, int64_t wkspace_size_, double scaling_factor_,
double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
/* Descriptor data type conversion */
size_t input_batch = static_cast<size_t>(input_batch_);
size_t bias_batch = static_cast<size_t>(bias_batch_);
size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_);
size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_);
size_t attn_heads = static_cast<size_t>(attn_heads_);
size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_);
size_t bias_heads = static_cast<size_t>(bias_heads_);
size_t head_dim = static_cast<size_t>(head_dim_);
size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_);
size_t wkspace_size = static_cast<size_t>(wkspace_size_);
float scaling_factor = static_cast<float>(scaling_factor_);
float dropout_probability = static_cast<float>(dropout_probability_);
NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_);
NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_);
NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
DType dtype = static_cast<DType>(dtype_);
DType wkspace_dtype = static_cast<DType>(wkspace_dtype_);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
/* q, k, v are parsed later for different qkv_layout */
void *bias = bias_buf.untyped_data();
void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data();
void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data();
void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
void *k_seq_offsets = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;
void *seed = seed_buf.untyped_data();
/* Output buffer from XLA */
void *output = output_buf->untyped_data();
void *softmax_aux = softmax_aux_buf->untyped_data();
void *rng_state = rng_state_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
backend, softmax_aux);
/* cuDNN workspace */
auto workspace_tensor =
TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
/* Call the underlying NVTE API */
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = q_buf.untyped_data();
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = q_buf.untyped_data();
auto kv = k_buf.untyped_data();
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = q_buf.untyped_data();
auto k = k_buf.untyped_data();
auto v = v_buf.untyped_data();
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_output_tensors);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // q
.Arg<Buffer_Type>() // k
.Arg<Buffer_Type>() // v
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // q_cu_seqlens
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.Arg<Buffer_Type>() // seed_buf
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // softmax_aux
.Ret<Buffer_Type>() // rng_state
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("input_batch")
.Attr<int64_t>("bias_batch")
.Attr<int64_t>("q_max_seqlen")
.Attr<int64_t>("kv_max_seqlen")
.Attr<int64_t>("attn_heads")
.Attr<int64_t>("num_gqa_groups")
.Attr<int64_t>("bias_heads")
.Attr<int64_t>("head_dim")
.Attr<int64_t>("max_segments_per_seq")
.Attr<int64_t>("wkspace_size")
.Attr<double>("scaling_factor")
.Attr<double>("dropout_probability")
.Attr<int64_t>("bias_type")
.Attr<int64_t>("mask_type")
.Attr<int64_t>("qkv_layout")
.Attr<int64_t>("dtype")
.Attr<int64_t>("wkspace_dtype")
.Attr<bool>("is_training")
.Attr<bool>("deterministic")
.Attr<int64_t>("window_size_left")
.Attr<int64_t>("window_size_right"),
FFI_CudaGraph_Traits);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
@@ -523,8 +716,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
-PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
-rng_state, bias);
+PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
+bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
+softmax_aux, rng_state, bias);
/* cuDNN workspace */
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
...
@@ -15,12 +15,21 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
switch (type) {
-case xla::ffi::DataType::F16:
-return DType::kFloat16;
+case xla::ffi::DataType::U8:
+return DType::kByte;
+break;
+case xla::ffi::DataType::S32:
+return DType::kInt32;
+break;
+case xla::ffi::DataType::S64:
+return DType::kInt64;
break;
case xla::ffi::DataType::F32:
return DType::kFloat32;
break;
+case xla::ffi::DataType::F16:
+return DType::kFloat16;
+break;
case xla::ffi::DataType::BF16:
return DType::kBFloat16;
break;
...
@@ -237,6 +237,78 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
sm_margin, stream);
}
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *bias = beta_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *amax_out = amax_out_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
auto x_dims = x_buf.dimensions();
auto gamma_dims = gamma_buf.dimensions();
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_dims = wkspace_buf->dimensions();
auto barrier_dims = barrier_buf->dimensions();
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
@@ -310,6 +382,85 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbeta_part_dtype, sm_margin, stream);
}
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype(dgamma_part_buf->element_type());
auto dbeta_part_dtype = convert_ffi_datatype_to_te_dtype(dbeta_part_buf->element_type());
auto *ograd = dz_buf.untyped_data();
auto *mu = mu_buf.untyped_data();
auto *rsigma = rsigma_buf.untyped_data();
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *xgrad = xgrad_buf->untyped_data();
auto *wgrad = wgrad_buf->untyped_data();
auto *dbeta = dbeta_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data();
auto x_dims = x_buf.dimensions();
auto gamma_dims = gamma_buf.dimensions();
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_dims = wkspace_buf->dimensions();
auto barrier_dims = barrier_buf->dimensions();
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
Shape dgamma_part_shape, dbeta_part_shape;
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Ret<Buffer_Type>() // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
...
@@ -52,9 +52,15 @@ pybind11::dict Registrations() {
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
return dict;
}
...
@@ -6,6 +6,7 @@
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
@@ -27,6 +28,41 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}
Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type amax_out_buf) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");
auto input_dims = input_buf.dimensions();
std::vector<size_t> shape(input_dims.begin(), input_dims.end());
auto input_tensor = TensorWrapper(input, shape, in_dtype);
auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv);
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(QuantizeHandler, QuantizeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>(), // amax_out
FFI_CudaGraph_Traits);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
...
@@ -36,6 +36,38 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
TransposeImpl(input, rows, cols, dtype, stream, output);
}
Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
void *input = input_buf.untyped_data();
void *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1,
std::multiplies<>());
auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1,
std::multiplies<>());
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype);
nvte_transpose(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(TransposeHandler, TransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
@@ -82,7 +114,7 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input_cast = input_cast_buf->untyped_data();
auto *input_cast_trans = input_cast_trans_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
-assert(amax == amax_out);
+NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
...