Unverified commit 4d65073f authored by Hua Huang, committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for three cast transpose functions (#1310)



* FFI for some transpose & activation functions
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove comments in transformer_engine/jax/csrc/extensions/activation.cpp
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Signed-off-by: Hua Huang <huangh1994@outlook.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Hua Huang <huangh1994@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d4aa2996
......@@ -22,6 +22,7 @@ from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose,
_jax_cast_transpose,
_jax_dbias_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex
......@@ -504,7 +505,6 @@ class TestActivationLuFP8(TestActivationLu):
scale_inv,
FP8Helper.BWD_DTYPE,
-1,
-2,
self.activation_type,
)
)
......@@ -812,6 +812,34 @@ class TestTranspose:
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
    "out_dtype",
    [
        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
    ],
)
def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
    """Compare dbias_cast_transpose across three code paths.

    Checks that the pure-JAX reference (`_jax_dbias_cast_transpose`), the
    non-FFI custom call, and the FFI custom call all agree.
    NOTE: the FP8 output dtypes require FP8-capable hardware, hence the
    skipif guard — matching the sibling FP8 test in this class.
    """
    amax = jnp.zeros(1, jnp.float32)
    scale = jnp.ones(1, jnp.float32)
    scale_inv = jnp.ones(1, jnp.float32)
    key = jax.random.PRNGKey(0)
    # Renamed from `input` to avoid shadowing the builtin.
    x = jax.random.uniform(key, input_shape, in_dtype)
    static_axis_boundary = -1  # -1 means no static axes
    jax_output = _jax_dbias_cast_transpose(
        x, amax, scale, out_dtype, static_axis_boundary, transpose_axis
    )
    # Custom-call path with FFI disabled.
    os.environ["NVTE_JAX_WITH_FFI"] = "0"
    noffi_output = tex.dbias_cast_transpose(
        x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
    )
    # Custom-call path with FFI enabled.
    os.environ["NVTE_JAX_WITH_FFI"] = "1"
    ffi_output = tex.dbias_cast_transpose(
        x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
    )
    assert_tree_like_allclose(jax_output, ffi_output)
    assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
......
......@@ -64,6 +64,35 @@ def _jax_cast_transpose(
return casted_output, casted_transposed_output, updated_amax
def _jax_dbias_cast_transpose(
    dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native dbias_cast_transpose implementation.

    Performs cast+transpose of dz and additionally reduces dz over all leading
    (batch) axes to produce the bias gradient.
    """
    casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
        dz,
        scale,
        amax,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
    # Normalize a negative boundary into a positive axis count, then sum dz
    # over every axis before that boundary.
    num_leading_axes = transpose_axis_boundary
    if num_leading_axes <= 0:
        num_leading_axes += dz.ndim
    dbias = jnp.sum(dz, axis=tuple(range(num_leading_axes)), keepdims=False)
    # The C++ function returns a 1D array for dbias; match that layout here.
    return casted_dz, cast_transposed_dz, dbias.ravel(), updated_amax
class TransposePrimitive(BasePrimitive):
"""
Transpose Primitive
......@@ -419,12 +448,7 @@ def cast_transpose(
"""
if not CastTransposePrimitive.enabled():
return _jax_cast_transpose(
x,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
)
return CastTransposePrimitive.outer_primitive.bind(
x,
......@@ -512,45 +536,53 @@ class DBiasCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
contracted_dz_shape = (batch_size, ir_hidden_size)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_dz_shape = multidim_transpose(
ir_dz_shape, static_axis_boundary, transpose_axis_boundary
)
dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)
if is_ffi_enabled():
name = "te_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})(
ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
contracted_dz_shape = (batch_size, ir_hidden_size)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_dz_shape = multidim_transpose(
ir_dz_shape, static_axis_boundary, transpose_axis_boundary
)
dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)
wkspace_aval = ctx.avals_out[-1]
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
]
operands = [dz, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_dz_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
)
out_types = [
ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_dz_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
)
out = custom_caller(
DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
)
out = custom_caller(
DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
)
return out
......@@ -677,26 +709,9 @@ def dbias_cast_transpose(
static_axis_boundary = -1 # means no static axes
if not DBiasCastTransposePrimitive.enabled():
casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
dz,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
return _jax_dbias_cast_transpose(
dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
)
dbias = jnp.sum(
dz,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dz.ndim
)
),
keepdims=False,
)
return casted_dz, cast_transposed_dz, dbias, updated_amax
return DBiasCastTransposePrimitive.outer_primitive.bind(
dz,
......@@ -716,8 +731,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
name = "te_dact_lu_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum
impl_static_args = (5, 6, 7, 8)
# out_dtype, static_axis_boundary, act_enum
impl_static_args = (5, 6, 7)
inner_primitive = None
outer_primitive = None
......@@ -731,7 +746,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum
): # pylint: disable=unused-argument
"""
......@@ -746,7 +760,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size
t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
......@@ -779,19 +793,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(
ctx,
dz,
x,
amax,
scale,
scale_inv,
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum
):
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
"""
te_dgated_act_lu_cast_transpose_p lowering rules
"""
......@@ -801,55 +803,67 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
ir_hidden_szie = ir_dz_shape[-1]
contracted_x_shape = (x_batch_size, ir_hidden_szie)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(
x_shape, static_axis_boundary, transpose_axis_boundary
)
dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie)
if is_ffi_enabled():
name = "te_dact_lu_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 3})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
ir_hidden_szie = ir_dz_shape[-1]
contracted_x_shape = (x_batch_size, ir_hidden_szie)
wkspace_aval = ctx.avals_out[-1]
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie)
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
act_enum,
)
wkspace_aval = ctx.avals_out[-1]
out = custom_caller(
DActLuDBiasCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 3},
)
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
act_enum,
)
out = custom_caller(
DActLuDBiasCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 3},
)
return out
......@@ -862,7 +876,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum,
):
"""
......@@ -877,21 +890,12 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum,
)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(
batched_args,
batch_dims,
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum
):
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
"""
to describe batch rules for vmap
"""
......@@ -901,10 +905,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
# Minus batch dim.
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return (
DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
......@@ -915,7 +915,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum,
),
out_bdims,
......@@ -925,7 +924,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
def infer_sharding_from_operands(
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum,
mesh,
arg_infos,
......@@ -934,7 +932,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
......@@ -946,7 +944,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
def partition(
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum,
mesh,
arg_infos,
......@@ -955,7 +952,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
......@@ -981,7 +978,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum,
)
)
......@@ -1003,7 +999,6 @@ def dact_lu_dbias_cast_transpose(
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
transpose_axis_boundary: int = -1,
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
......@@ -1017,27 +1012,10 @@ def dact_lu_dbias_cast_transpose(
if not DActLuDBiasCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose(
dx,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.squeeze(
jnp.sum(
dx,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dx.ndim
)
),
)
transpose_axis_boundary = -2
return _jax_dbias_cast_transpose(
dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
)
return casted_dx, cast_transposed_dx, dbias, updated_amax
act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
......@@ -1048,7 +1026,6 @@ def dact_lu_dbias_cast_transpose(
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_type_id,
)
......@@ -1106,47 +1083,59 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_szie == gi_hidden_size
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
if is_ffi_enabled():
name = "te_dgated_act_lu_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_szie == gi_hidden_size
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
out = custom_caller(
DgatedActLuCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 2},
)
out = custom_caller(
DgatedActLuCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 2},
)
return out
......
......@@ -155,27 +155,29 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler);
// Activation
size_t get_activation_len(NVTE_Activation_Type activation_enum);
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
......@@ -184,9 +186,13 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler);
// Normalization
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
......@@ -373,7 +373,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
......@@ -422,6 +422,107 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
}
}
// FFI entry point for the fused dact_lu + dbias + cast + transpose kernel.
// Mirrors the opaque-descriptor path above, but derives (m, n) from the
// buffer dimensions instead of a packed descriptor.
Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                       Buffer_Type act_input_buf, Buffer_Type amax_buf,
                                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                       Result_Type output_buf, Result_Type output_trans_buf,
                                       Result_Type dbias_buf, Result_Type amax_out_buf,
                                       Result_Type workspace_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  void *workspace = workspace_buf->untyped_data();
  // The caller aliases amax to amax_out; amax is updated in place.
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
  // Non-FP8 outputs carry no scaling metadata.
  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
  // n = ir_dz_shape[-1], ir_dz_shape == input_dims
  auto input_ranks = input_dims.size();
  auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  auto n = product(input_dims, input_ranks - 1, input_ranks);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  // Fix: wrap act_input with act_input_shape (was input_shape, which left
  // act_input_shape unused; the two shapes are identical {m, n}).
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
  // Dispatch on the activation whose derivative is fused into the kernel.
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler, DActLuDBiasCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
......@@ -444,7 +545,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto input_shape = desc.shape.to_vector();
auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
......@@ -484,5 +585,88 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
}
}
// FFI entry point for the fused dgated_act_lu + cast + transpose kernel.
// The gated activation input has an extra size-2 axis (linear + gate), so the
// activation-side buffers are {m, 2n} while the incoming gradient is {m, n}.
Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                       Buffer_Type act_input_buf, Buffer_Type amax_buf,
                                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                       Result_Type output_buf, Result_Type output_trans_buf,
                                       Result_Type amax_out_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  // The caller aliases amax to amax_out; amax is updated in place.
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
  // Non-FP8 outputs carry no scaling metadata.
  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  // m = batch size from all-but-last-two act_input axes, n = hidden size from
  // the last act_input axis. (Removed the unused input_dims local.)
  auto act_input_dims = act_input_buf.dimensions();
  auto act_input_ranks = act_input_dims.size();
  auto m = product(act_input_dims, 0, act_input_ranks - 2);
  auto n = product(act_input_dims, act_input_ranks - 1, act_input_ranks);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
  // Dispatch on the gated activation whose derivative is fused into the kernel.
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler, DGatedActLuCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
......@@ -55,11 +55,16 @@ pybind11::dict Registrations() {
// Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_dbias_cast_transpose_ffi"] = EncapsulateFFI(DBiasCastTransposeHandler);
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
dict["te_dact_lu_dbias_cast_transpose_ffi"] =
EncapsulateFunction(DActLuDBiasCastTransposeHandler);
dict["te_dgated_act_lu_cast_transpose_ffi"] =
EncapsulateFunction(DGatedActLuCastTransposeHandler);
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
......
......@@ -100,18 +100,18 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
Result_Type input_cast_buf, Result_Type input_cast_trans_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type amax_out_buf, int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *input_cast = input_cast_buf->untyped_data();
auto *input_cast_trans = input_cast_trans_buf->untyped_data();
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
......@@ -126,15 +126,15 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size());
auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m};
auto output_shape = input_shape;
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto input_cast_tensor =
TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv);
auto input_cast_trans_tensor =
TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(),
nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
stream);
return ffi_with_cuda_error_check();
}
......@@ -146,8 +146,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // input_cast
.Ret<Buffer_Type>() // input_cast_trans
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // output_trans
.Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
......@@ -213,5 +213,70 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), workspace.data(), stream);
}
// XLA FFI entry point for the fused cast + transpose + dbias primitive.
// Flattens `input` to a 2-D (m, n) view around `transpose_axis`, then launches
// one fused kernel that writes the casted tensor (m, n), its transpose (n, m),
// and a bias-gradient tensor of shape (n), using `workspace` as scratch.
Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                                 Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                 Result_Type output_buf, Result_Type output_trans_buf,
                                 Result_Type dbias_buf, Result_Type amax_out_buf,
                                 Result_Type workspace_buf, int64_t transpose_axis) {
  // Element types: the cast output dtype may be FP8; the workspace dtype is
  // taken from the buffer XLA allocated for it.
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  // Raw device pointers for every operand.
  void *input_ptr = input_buf.untyped_data();
  void *cast_ptr = output_buf->untyped_data();
  void *cast_trans_ptr = output_trans_buf->untyped_data();
  void *dbias_ptr = dbias_buf->untyped_data();
  void *workspace_ptr = workspace_buf->untyped_data();
  auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());

  // The primitive requires the amax input and output to be the same buffer
  // (amax is updated in place through amax_out).
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
  if (!use_fp8(out_dtype)) {
    // Non-FP8 outputs carry no quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  // Collapse the input dims into (rows, cols) split at transpose_axis
  // (negative axis counts from the end, numpy-style).
  auto in_dims = input_buf.dimensions();
  if (transpose_axis < 0) transpose_axis += in_dims.size();
  auto rows = product(in_dims, 0, transpose_axis);
  auto cols = product(in_dims, transpose_axis, in_dims.size());
  auto shape_2d = std::vector<size_t>{rows, cols};
  auto shape_2d_t = std::vector<size_t>{cols, rows};
  auto bias_shape = std::vector<size_t>{cols};
  auto ws_dims = workspace_buf->dimensions();
  std::vector<size_t> ws_shape(ws_dims.begin(), ws_dims.end());

  // Wrap everything for the NVTE C API; dbias stays in the input dtype.
  auto input_tensor = TensorWrapper(input_ptr, shape_2d, in_dtype);
  auto cast_tensor = TensorWrapper(cast_ptr, shape_2d, out_dtype, amax_out, scale, scale_inv);
  auto cast_trans_tensor =
      TensorWrapper(cast_trans_ptr, shape_2d_t, out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias_ptr, bias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace_ptr, ws_shape, workspace_dtype);

  nvte_cast_transpose_dbias(input_tensor.data(), cast_tensor.data(), cast_trans_tensor.data(),
                            dbias_tensor.data(), workspace_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}
// Bind DBiasCastTransposeFFI as an XLA FFI handler: 4 input buffers (input,
// amax, scale, scale_inv), 5 results (output, output_trans, dbias, amax_out,
// workspace), plus the "transpose_axis" int attribute. The order here must
// match the parameter order of DBiasCastTransposeFFI above.
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as safe for
// CUDA-graph capture — confirm against the XLA FFI traits definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasCastTransposeHandler, DBiasCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("transpose_axis"),
                              FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
......@@ -516,7 +516,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2,
activation_type=activation_type,
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment