"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "1ec33ae1191ae6644365155f8e8f618145c44cd7"
Unverified Commit 18c2234c authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[JAX] XLA Custom Calls with FFI for FusedAttnFwd, Quantize, Transpose, ActLuFP8, LayerNormForwardFP8FFI, and LayerNormBackwardFFI (#1263)

* Add TransposeFFI, test passed
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add ActLuFP8FFI; fix TransposeFFI
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add QuantizeFFI
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FusedAttnForwardFFI and some unit tests
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor fix
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add LayerNormForwardFP8FFI & LayerNormBackwardFFI
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revise FusedAttnForwardFFI()
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add FFI_CudaGraph_Traits

All tests passed; ready for merge
Signed-off-by: Hua Huang <huah@nvidia.com>

* Bug fix for FFI data type mismatch

Also add a safeguard at the entrance to the FFI function
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 20c75295
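
The new FFI lowering is selected at trace time through the NVTE_JAX_WITH_FFI environment variable; the tests in this commit flip it around each call and check that the legacy opaque-descriptor path and the new typed-FFI path agree. A minimal standalone sketch of that pattern (the shape and the final tolerance check are illustrative, not taken from this commit; tex.transpose is called positionally as in the tests below):

    import os

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax import cpp_extensions as tex

    x = jax.random.uniform(jax.random.PRNGKey(0), (256, 128), jnp.float32)

    # Legacy path: descriptor packed into an opaque byte string
    os.environ["NVTE_JAX_WITH_FFI"] = "0"
    legacy_out = tex.transpose(x, -1, 1)  # (x, static_axis_boundary, transpose_axis)

    # New path: typed XLA FFI custom call
    os.environ["NVTE_JAX_WITH_FFI"] = "1"
    ffi_out = tex.transpose(x, -1, 1)

    assert jnp.allclose(legacy_out, ffi_out)
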
@@ -3,9 +3,8 @@
# See LICENSE for license information.
from contextlib import nullcontext
import functools
import operator
from typing import Callable, List, Sequence, Union
import os
import jax
import jax.numpy as jnp
@@ -14,12 +13,17 @@ import pytest
from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose
from utils import assert_allclose, assert_tree_like_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose,
_jax_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex
@@ -746,3 +750,102 @@ class TestNorm:
assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
if beta is not None:
assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.parametrize(
"in_dtype",
[
pytest.param(jnp.float32, id="input_float32"),
pytest.param(jnp.float16, id="input_float16"),
pytest.param(jnp.bfloat16, id="input_bfloat16"),
],
)
@pytest.mark.parametrize(
"input_shape, transpose_axis",
[
pytest.param((16, 16), 1, id="(16, 16)-1"),
pytest.param((256, 128), 1, id="(256, 128)-1"),
pytest.param((128, 512), 1, id="(128, 512)-1"),
pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
],
)
class TestTranspose:
def test_transpose(self, in_dtype, input_shape, transpose_axis):
key = jax.random.PRNGKey(0)
input_tensor = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
assert_allclose(jax_output, noffi_output)
assert_allclose(noffi_output, ffi_output)
@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_cast_transpose(
input, scale, amax, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
"input_shape",
[
pytest.param((256, 128), id="(256, 128)"),
pytest.param((128, 512, 8), id="(128, 512, 8)"),
],
)
@pytest.mark.parametrize(
"in_dtype",
[
pytest.param(jnp.float32, id="input_float32"),
pytest.param(jnp.float16, id="input_float16"),
pytest.param(jnp.bfloat16, id="input_bfloat16"),
],
)
@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_quantize(input_shape, in_dtype, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
jax_output = _jax_cast_fp8(input, scale, amax, out_dtype)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)
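
For reference, the FP8 quantize these tests exercise amounts, in pure JAX terms, to scaling, clipping to the FP8 range, casting, and reporting a new amax. A hedged sketch of that contract (not the actual _jax_cast_fp8 source; the clipping step and whether the running amax is folded in are assumptions here):

    import jax.numpy as jnp

    def quantize_fp8_reference(x, scale, amax, out_dtype):
        # Scale in float32, clip to the FP8 dtype's finite range, then cast.
        fp8_max = float(jnp.finfo(out_dtype).max)
        scaled = x.astype(jnp.float32) * scale
        out = jnp.clip(scaled, -fp8_max, fp8_max).astype(out_dtype)
        # Report the observed absolute maximum for the next scale update.
        new_amax = jnp.maximum(amax, jnp.max(jnp.abs(x)).astype(jnp.float32))
        return out, new_amax
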
@@ -67,7 +67,7 @@ BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_HIDDEN_DROPOUT: 0,
_KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_ATTENTION_DROPOUT: 0.0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm",
......
@@ -383,37 +383,43 @@ class ActLuFp8Primitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
if is_ffi_enabled():
name = "te_act_lu_fp8_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv, act_enum=act_enum
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out = custom_caller(
ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
out = custom_caller(
ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
return out
......
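
is_ffi_enabled() above is imported from .misc and its body is not part of this diff. A plausible sketch, assuming it simply gates the lowering on the same NVTE_JAX_WITH_FFI environment variable the tests toggle (the real helper may also need to check the running JAX version, since jax.extend.ffi is a relatively recent API):

    import os

    def is_ffi_enabled():
        # Hypothetical: FFI lowering is on unless explicitly disabled.
        return os.getenv("NVTE_JAX_WITH_FFI", "1") == "1"
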
@@ -15,6 +15,7 @@ from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
@@ -33,6 +34,7 @@ from .misc import (
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
is_ffi_enabled,
)
from ..sharding import (
global_mesh_resource,
@@ -352,14 +354,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
@@ -376,31 +370,85 @@
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
wkspace_size=wkspace_aval.size,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
dtype=int(jax_dtype_to_te_dtype(q_aval.dtype)),
wkspace_dtype=int(jax_dtype_to_te_dtype(wkspace_aval.dtype)),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
......
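
Unlike the legacy branch, which packs every descriptor field into one opaque byte string, ffi.ffi_lowering forwards keyword arguments as typed FFI attributes, and each must match an .Attr<>() declaration in the C++ binding by name and type (Python int maps to int64_t, float to double, bool to bool, hence the explicit float()/int() casts above). A hedged minimal pairing with a hypothetical target name and attribute:

    from jax.extend import ffi

    def lower_my_op(ctx, x, *, scaling_factor):
        # "te_my_op_ffi" is illustrative; the C++ side would declare
        # .Attr<double>("scaling_factor") in its FFI::Bind() chain.
        return ffi.ffi_lowering("te_my_op_ffi")(ctx, x, scaling_factor=float(scaling_factor))
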
@@ -9,6 +9,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
@@ -20,6 +21,7 @@ from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
is_ffi_enabled,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP
@@ -84,30 +86,36 @@ class CastFP8Primitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [
ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype)
)
if is_ffi_enabled():
name = "te_quantize_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [
ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype)
)
out = custom_caller(
CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
out = custom_caller(
CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
return out
......
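
In both branches above, operand_output_aliases={1: 1} tells XLA that operand 1 (amax) donates its buffer to output 1 (the updated amax), which is exactly what the new NVTE_CHECK safeguards in the C++ handlers verify: the input and output amax pointers must be identical. As a loose JAX-level analogy only (buffer donation through jit, not a custom call; the function is hypothetical):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, donate_argnums=(0,))
    def update_amax(amax, x):
        # XLA may reuse the donated amax buffer for the result,
        # much like the aliased amax/amax_out pair in the FFI handlers.
        return jnp.maximum(amax, jnp.max(jnp.abs(x)))
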
@@ -102,32 +102,36 @@ class TransposePrimitive(BasePrimitive):
jnp.float8_e5m2,
]
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
if static_axis_boundary >= 0:
for i in range(static_axis_boundary + 1):
assert ir_x_shape[i] == 1
if is_ffi_enabled():
name = "te_transpose_ffi"
out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
if static_axis_boundary >= 0:
for i in range(static_axis_boundary + 1):
assert ir_x_shape[i] == 1
transposed_x_shape = multidim_transpose(
ir_x_shape, static_axis_boundary, transpose_axis_boundary
)
transposed_x_shape = multidim_transpose(
ir_x_shape, static_axis_boundary, transpose_axis_boundary
)
out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
contracted_x_shape = (
reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
)
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape, te_dtype, te_dtype
)
te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
contracted_x_shape = (
reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
)
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape, te_dtype, te_dtype
)
out = custom_caller(TransposePrimitive.name, args, opaque, False)
out = custom_caller(TransposePrimitive.name, args, opaque, False)
return out
......
@@ -151,6 +151,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
@@ -172,6 +174,8 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
@@ -195,6 +199,8 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
@@ -202,6 +208,8 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -212,6 +220,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Softmax
@@ -253,6 +263,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
......
@@ -153,6 +153,51 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
act_enum, act_len);
}
Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type amax_out_buf, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto input_dims = input_buf.dimensions();
auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>());
auto n = input_dims.back();
auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
ActLuImpl(input, m, n, in_dtype, out_dtype, scale, stream, scale_inv, amax_out, output, act_type,
act_len);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *act_input = buffers[1];
......
@@ -30,18 +30,13 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
- common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
- common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
const CustomCallFusedAttnDescriptor *desc,
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
const size_t bias_batch, const size_t attn_heads,
const size_t bias_heads, const size_t q_max_seqlen,
const size_t kv_max_seqlen, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr,
void *bias_buf = nullptr) {
auto input_batch = desc->input_batch;
auto bias_batch = desc->bias_batch;
auto attn_heads = desc->attn_heads;
auto bias_heads = desc->bias_heads;
auto q_max_seqlen = desc->q_max_seqlen;
auto kv_max_seqlen = desc->kv_max_seqlen;
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1;
@@ -49,7 +44,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape =
std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = desc->dtype;
softmax_aux->data.dtype = dtype;
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
@@ -69,7 +64,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
bias_aux->data.dptr = bias_buf;
bias_aux->data.shape =
std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
bias_aux->data.dtype = desc->dtype;
bias_aux->data.dtype = dtype;
}
}
}
@@ -82,22 +77,25 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
const CustomCallFusedAttnDescriptor *desc,
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
const size_t bias_batch, const size_t attn_heads,
const size_t bias_heads, const size_t q_max_seqlen,
const size_t kv_max_seqlen, DType dtype,
NVTE_Fused_Attn_Backend backend, void *softmax_buf,
void *rng_state_buf, void *bias_buf) {
// Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path
auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, softmax_buf,
rng_state_buf, bias_buf);
PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
dummy_backend, softmax_buf, rng_state_buf, bias_buf);
// correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = desc->dtype;
softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = dtype;
}
}
@@ -190,7 +188,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
@@ -279,8 +276,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
backend, softmax_aux);
/* cuDNN workspace */
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
@@ -335,6 +333,201 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_tensor_pack_destroy(&aux_output_tensors);
}
Error_Type FusedAttnForwardFFI(
cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, Buffer_Type seed_buf,
Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, int64_t input_batch_, int64_t bias_batch_, int64_t q_max_seqlen_,
int64_t kv_max_seqlen_, int64_t attn_heads_, int64_t num_gqa_groups_, int64_t bias_heads_,
int64_t head_dim_, int64_t max_segments_per_seq_, int64_t wkspace_size_, double scaling_factor_,
double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
/* Descriptor data type conversion */
size_t input_batch = static_cast<size_t>(input_batch_);
size_t bias_batch = static_cast<size_t>(bias_batch_);
size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_);
size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_);
size_t attn_heads = static_cast<size_t>(attn_heads_);
size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_);
size_t bias_heads = static_cast<size_t>(bias_heads_);
size_t head_dim = static_cast<size_t>(head_dim_);
size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_);
size_t wkspace_size = static_cast<size_t>(wkspace_size_);
float scaling_factor = static_cast<float>(scaling_factor_);
float dropout_probability = static_cast<float>(dropout_probability_);
NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_);
NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_);
NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
DType dtype = static_cast<DType>(dtype_);
DType wkspace_dtype = static_cast<DType>(wkspace_dtype_);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
/* q, k, v are parsed later for different qkv_layout */
void *bias = bias_buf.untyped_data();
void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data();
void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data();
void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
void *k_seq_offsets = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;
void *seed = seed_buf.untyped_data();
/* Output buffer from XLA */
void *output = output_buf->untyped_data();
void *softmax_aux = softmax_aux_buf->untyped_data();
void *rng_state = rng_state_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
backend, softmax_aux);
/* cuDNN workspace */
auto workspace_tensor =
TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
/* Call the underlying NVTE API */
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = q_buf.untyped_data();
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = q_buf.untyped_data();
auto kv = k_buf.untyped_data();
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = q_buf.untyped_data();
auto k = k_buf.untyped_data();
auto v = v_buf.untyped_data();
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_output_tensors);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // q
.Arg<Buffer_Type>() // k
.Arg<Buffer_Type>() // v
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // q_cu_seqlens
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.Arg<Buffer_Type>() // seed_buf
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // softmax_aux
.Ret<Buffer_Type>() // rng_state
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("input_batch")
.Attr<int64_t>("bias_batch")
.Attr<int64_t>("q_max_seqlen")
.Attr<int64_t>("kv_max_seqlen")
.Attr<int64_t>("attn_heads")
.Attr<int64_t>("num_gqa_groups")
.Attr<int64_t>("bias_heads")
.Attr<int64_t>("head_dim")
.Attr<int64_t>("max_segments_per_seq")
.Attr<int64_t>("wkspace_size")
.Attr<double>("scaling_factor")
.Attr<double>("dropout_probability")
.Attr<int64_t>("bias_type")
.Attr<int64_t>("mask_type")
.Attr<int64_t>("qkv_layout")
.Attr<int64_t>("dtype")
.Attr<int64_t>("wkspace_dtype")
.Attr<bool>("is_training")
.Attr<bool>("deterministic")
.Attr<int64_t>("window_size_left")
.Attr<int64_t>("window_size_right"),
FFI_CudaGraph_Traits);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
@@ -523,8 +716,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
/* cuDNN workspace */
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
......
@@ -15,12 +15,21 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
switch (type) {
case xla::ffi::DataType::F16:
return DType::kFloat16;
case xla::ffi::DataType::U8:
return DType::kByte;
break;
case xla::ffi::DataType::S32:
return DType::kInt32;
break;
case xla::ffi::DataType::S64:
return DType::kInt64;
break;
case xla::ffi::DataType::F32:
return DType::kFloat32;
break;
case xla::ffi::DataType::F16:
return DType::kFloat16;
break;
case xla::ffi::DataType::BF16:
return DType::kBFloat16;
break;
......
@@ -237,6 +237,78 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
sm_margin, stream);
}
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *bias = beta_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *amax_out = amax_out_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
auto x_dims = x_buf.dimensions();
auto gamma_dims = gamma_buf.dimensions();
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_dims = wkspace_buf->dimensions();
auto barrier_dims = barrier_buf->dimensions();
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
@@ -310,6 +382,85 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbeta_part_dtype, sm_margin, stream);
}
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype(dgamma_part_buf->element_type());
auto dbeta_part_dtype = convert_ffi_datatype_to_te_dtype(dbeta_part_buf->element_type());
auto *ograd = dz_buf.untyped_data();
auto *mu = mu_buf.untyped_data();
auto *rsigma = rsigma_buf.untyped_data();
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *xgrad = xgrad_buf->untyped_data();
auto *wgrad = wgrad_buf->untyped_data();
auto *dbeta = dbeta_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data();
auto x_dims = x_buf.dimensions();
auto gamma_dims = gamma_buf.dimensions();
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_dims = wkspace_buf->dimensions();
auto barrier_dims = barrier_buf->dimensions();
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
Shape dgamma_part_shape, dbeta_part_shape;
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Ret<Buffer_Type>() // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
......
@@ -52,9 +52,15 @@ pybind11::dict Registrations() {
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
return dict;
}
......
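
Each *_ffi entry above wraps an XLA FFI handler symbol (EncapsulateFFI) rather than a legacy function pointer (EncapsulateFunction), so the Python side has to register the two kinds differently. A hedged sketch of that dispatch, assuming the pybind module exposes Registrations() as registrations() (the actual registration code is not part of this diff):

    from jax.extend import ffi
    from jax.lib import xla_client

    from transformer_engine import transformer_engine_jax

    for name, capsule in transformer_engine_jax.registrations().items():
        if name.endswith("_ffi"):
            # New typed-FFI handlers (XLA_FFI_DEFINE_HANDLER_SYMBOL)
            ffi.register_ffi_target(name, capsule, platform="CUDA")
        else:
            # Legacy opaque custom calls
            xla_client.register_custom_call_target(name, capsule, platform="CUDA")
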
@@ -6,6 +6,7 @@
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
@@ -27,6 +28,41 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}
Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type amax_out_buf) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");
auto input_dims = input_buf.dimensions();
std::vector<size_t> shape(input_dims.begin(), input_dims.end());
auto input_tensor = TensorWrapper(input, shape, in_dtype);
auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv);
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(QuantizeHandler, QuantizeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>(), // amax_out
FFI_CudaGraph_Traits);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
......
@@ -36,6 +36,38 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
TransposeImpl(input, rows, cols, dtype, stream, output);
}
Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
void *input = input_buf.untyped_data();
void *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1,
std::multiplies<>());
auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1,
std::multiplies<>());
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype);
nvte_transpose(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(TransposeHandler, TransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
@@ -82,7 +114,7 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input_cast = input_cast_buf->untyped_data();
auto *input_cast_trans = input_cast_trans_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
......