"vscode:/vscode.git/clone" did not exist on "311490a720951f322977d811eacea685a623b5a1"
Unverified Commit 4d65073f authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for three cast transpose functions (#1310)



* FFI for some transpose & activation functions
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove comments in transformer_engine/jax/csrc/extensions/activation.cpp
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Signed-off-by: default avatarHua Huang <huangh1994@outlook.com>

---------
Signed-off-by: default avatarHua Huang <huah@nvidia.com>
Signed-off-by: default avatarHua Huang <huangh1994@outlook.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d4aa2996
...@@ -22,6 +22,7 @@ from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu ...@@ -22,6 +22,7 @@ from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import ( from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose, _jax_transpose,
_jax_cast_transpose, _jax_cast_transpose,
_jax_dbias_cast_transpose,
) )
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax import cpp_extensions as tex
...@@ -504,7 +505,6 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -504,7 +505,6 @@ class TestActivationLuFP8(TestActivationLu):
scale_inv, scale_inv,
FP8Helper.BWD_DTYPE, FP8Helper.BWD_DTYPE,
-1, -1,
-2,
self.activation_type, self.activation_type,
) )
) )
...@@ -812,6 +812,34 @@ class TestTranspose: ...@@ -812,6 +812,34 @@ class TestTranspose:
assert_tree_like_allclose(jax_output, ffi_output) assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output) assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_dbias_cast_transpose(
input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.dbias_cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.dbias_cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -64,6 +64,35 @@ def _jax_cast_transpose( ...@@ -64,6 +64,35 @@ def _jax_cast_transpose(
return casted_output, casted_transposed_output, updated_amax return casted_output, casted_transposed_output, updated_amax
def _jax_dbias_cast_transpose(
dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
JAX native dbias_cast_transpose implementation
"""
casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
dz,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.sum(
dz,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dz.ndim
)
),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return casted_dz, cast_transposed_dz, dbias, updated_amax
class TransposePrimitive(BasePrimitive): class TransposePrimitive(BasePrimitive):
""" """
Transpose Primitive Transpose Primitive
...@@ -419,12 +448,7 @@ def cast_transpose( ...@@ -419,12 +448,7 @@ def cast_transpose(
""" """
if not CastTransposePrimitive.enabled(): if not CastTransposePrimitive.enabled():
return _jax_cast_transpose( return _jax_cast_transpose(
x, x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
) )
return CastTransposePrimitive.outer_primitive.bind( return CastTransposePrimitive.outer_primitive.bind(
x, x,
...@@ -512,6 +536,12 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -512,6 +536,12 @@ class DBiasCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})(
ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
)
else:
ir_dz_type = ir.RankedTensorType(dz.type) ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape ir_dz_shape = ir_dz_type.shape
batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary]) batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
...@@ -535,7 +565,9 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -535,7 +565,9 @@ class DBiasCastTransposePrimitive(BasePrimitive):
ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype), ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
] ]
operands = [dz, amax, scale, scale_inv] operands = [dz, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
...@@ -677,26 +709,9 @@ def dbias_cast_transpose( ...@@ -677,26 +709,9 @@ def dbias_cast_transpose(
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
if not DBiasCastTransposePrimitive.enabled(): if not DBiasCastTransposePrimitive.enabled():
casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( return _jax_dbias_cast_transpose(
dz, dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.sum(
dz,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dz.ndim
) )
),
keepdims=False,
)
return casted_dz, cast_transposed_dz, dbias, updated_amax
return DBiasCastTransposePrimitive.outer_primitive.bind( return DBiasCastTransposePrimitive.outer_primitive.bind(
dz, dz,
...@@ -716,8 +731,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -716,8 +731,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
name = "te_dact_lu_dbias_cast_transpose" name = "te_dact_lu_dbias_cast_transpose"
multiple_results = True multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum # out_dtype, static_axis_boundary, act_enum
impl_static_args = (5, 6, 7, 8) impl_static_args = (5, 6, 7)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -731,7 +746,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -731,7 +746,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
static_axis_boundary, static_axis_boundary,
transpose_axis_boundary,
act_enum act_enum
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
""" """
...@@ -746,7 +760,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -746,7 +760,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
ir_hidden_szie = dz_aval.shape[-1] ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1] gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size assert ir_hidden_szie == gi_hidden_size
t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
...@@ -779,19 +793,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -779,19 +793,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
return out, t_out, dbias, updated_amax_aval return out, t_out, dbias, updated_amax_aval
@staticmethod @staticmethod
def lowering( def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
ctx,
dz,
x,
amax,
scale,
scale_inv,
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum
):
""" """
te_dgated_act_lu_cast_transpose_p lowering rules te_dgated_act_lu_cast_transpose_p lowering rules
""" """
...@@ -801,6 +803,12 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -801,6 +803,12 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dact_lu_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 3})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type) ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
...@@ -817,9 +825,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -817,9 +825,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
ir_amax_shape = ir_amax_type.shape ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose( transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
x_shape, static_axis_boundary, transpose_axis_boundary
)
dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie) dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie)
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
...@@ -829,10 +835,18 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -829,10 +835,18 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
] ]
operands = [dz, x, amax, scale, scale_inv] operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor( opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape, contracted_x_shape,
...@@ -862,7 +876,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -862,7 +876,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype, out_dtype,
static_axis_boundary, static_axis_boundary,
transpose_axis_boundary,
act_enum, act_enum,
): ):
""" """
...@@ -877,21 +890,12 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -877,21 +890,12 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum, act_enum=act_enum,
) )
return out, t_out, dbias, updated_amax return out, t_out, dbias, updated_amax
@staticmethod @staticmethod
def batcher( def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
batched_args,
batch_dims,
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary,
act_enum
):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
...@@ -901,10 +905,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -901,10 +905,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
dz, x, amax, scale, scale_inv = batched_args dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims x_bdim, _, amax_bdim, _, _ = batch_dims
# Minus batch dim.
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return ( return (
DActLuDBiasCastTransposePrimitive.outer_primitive.bind( DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
...@@ -915,7 +915,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -915,7 +915,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=x_bdim, static_axis_boundary=x_bdim,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum, act_enum=act_enum,
), ),
out_bdims, out_bdims,
...@@ -925,7 +924,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -925,7 +924,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
def infer_sharding_from_operands( def infer_sharding_from_operands(
out_dtype, out_dtype,
static_axis_boundary, static_axis_boundary,
transpose_axis_boundary,
act_enum, act_enum,
mesh, mesh,
arg_infos, arg_infos,
...@@ -934,7 +932,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -934,7 +932,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
del out_dtype, result_infos, act_enum del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding( dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
...@@ -946,7 +944,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -946,7 +944,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
def partition( def partition(
out_dtype, out_dtype,
static_axis_boundary, static_axis_boundary,
transpose_axis_boundary,
act_enum, act_enum,
mesh, mesh,
arg_infos, arg_infos,
...@@ -955,7 +952,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -955,7 +952,7 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding( dbias_shaprding = NamedSharding(
...@@ -981,7 +978,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -981,7 +978,6 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum, act_enum=act_enum,
) )
) )
...@@ -1003,7 +999,6 @@ def dact_lu_dbias_cast_transpose( ...@@ -1003,7 +999,6 @@ def dact_lu_dbias_cast_transpose(
scale_inv: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: TEDType, out_dtype: TEDType,
static_axis_boundary: int, static_axis_boundary: int,
transpose_axis_boundary: int = -1,
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
""" """
...@@ -1017,27 +1012,10 @@ def dact_lu_dbias_cast_transpose( ...@@ -1017,27 +1012,10 @@ def dact_lu_dbias_cast_transpose(
if not DActLuDBiasCastTransposePrimitive.enabled(): if not DActLuDBiasCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz) (dx,) = vjp_func(dz)
casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose( transpose_axis_boundary = -2
dx, return _jax_dbias_cast_transpose(
scale, dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
) )
dbias = jnp.squeeze(
jnp.sum(
dx,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dx.ndim
)
),
)
)
return casted_dx, cast_transposed_dx, dbias, updated_amax
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
...@@ -1048,7 +1026,6 @@ def dact_lu_dbias_cast_transpose( ...@@ -1048,7 +1026,6 @@ def dact_lu_dbias_cast_transpose(
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_type_id, act_enum=act_type_id,
) )
...@@ -1106,6 +1083,12 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): ...@@ -1106,6 +1083,12 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dgated_act_lu_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type) ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
...@@ -1130,7 +1113,13 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): ...@@ -1130,7 +1113,13 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive):
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
] ]
operands = [dz, x, amax, scale, scale_inv] operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1]) contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor( opaque = transformer_engine_jax.pack_common_descriptor(
......
...@@ -155,27 +155,29 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler); ...@@ -155,27 +155,29 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype); DType in_dtype, DType out_dtype);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler);
// Activation // Activation
size_t get_activation_len(NVTE_Activation_Type activation_enum); size_t get_activation_len(NVTE_Activation_Type activation_enum);
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
...@@ -184,9 +186,13 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_ ...@@ -184,9 +186,13 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler);
// Normalization // Normalization
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
...@@ -373,7 +373,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -373,7 +373,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum); auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n}; auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
...@@ -422,6 +422,107 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -422,6 +422,107 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
} }
} }
Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type dbias_buf, Result_Type amax_out_buf,
Result_Type workspace_buf, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
void *workspace = workspace_buf->untyped_data();
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto input_dims = input_buf.dimensions();
auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
auto input_ranks = input_dims.size();
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = product(input_dims, input_ranks - 1, input_ranks);
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
switch (act_type) {
case NVTE_Activation_Type::GELU:
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace_tensor.data(), stream);
break;
case NVTE_Activation_Type::SILU:
nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace_tensor.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace_tensor.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler, DActLuDBiasCastTransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act_input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // output_trans
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
...@@ -444,7 +545,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -444,7 +545,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum); auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto input_shape = desc.shape.to_vector(); auto input_shape = desc.shape.to_vector();
auto act_input_shape = std::vector<size_t>{m, n * 2}; auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2}; auto output_shape = std::vector<size_t>{m, n * 2};
...@@ -484,5 +585,88 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -484,5 +585,88 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
} }
} }
// XLA FFI entry point fusing the backward pass of a gated activation
// (GEGLU/SWIGLU/REGLU/QGEGLU/SREGLU) with an FP8 cast and transpose.
//
// Args:
//   input_buf     - incoming gradient, treated as a 2-D (m, n) matrix.
//   act_input_buf - forward-pass activation input; its leading dims collapse
//                   to m and its last dim is n (logical shape (m, 2*n) after
//                   accounting for the gate/value pair).
//   amax_buf / scale_buf / scale_inv_buf - FP8 scaling metadata.
// Results:
//   output_buf       - casted dgrad w.r.t. the activation input, (m, 2*n).
//   output_trans_buf - its transpose, (2*n, m).
//   amax_out_buf     - updated amax; must alias amax_buf (buffer donation).
//   act_enum         - selects which nvte_d*_cast_transpose kernel to launch.
Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                       Buffer_Type act_input_buf, Buffer_Type amax_buf,
                                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                       Result_Type output_buf, Result_Type output_trans_buf,
                                       Result_Type amax_out_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());

  // The kernels update amax in place, so the input amax buffer must be
  // donated to (alias) the amax output.
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
  if (!use_fp8(out_dtype)) {
    // Non-FP8 outputs carry no scaling metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  // All shapes derive from act_input: leading dims (all but the last two)
  // collapse to m, the last dim is n.
  auto act_input_dims = act_input_buf.dimensions();
  auto act_input_ranks = act_input_dims.size();
  auto m = product(act_input_dims, 0, act_input_ranks - 2);
  auto n = product(act_input_dims, act_input_ranks - 1, act_input_ranks);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler, DGatedActLuCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -55,11 +55,16 @@ pybind11::dict Registrations() { ...@@ -55,11 +55,16 @@ pybind11::dict Registrations() {
// Transpose // Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_dbias_cast_transpose_ffi"] = EncapsulateFFI(DBiasCastTransposeHandler);
// Activation // Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
dict["te_dact_lu_dbias_cast_transpose_ffi"] =
EncapsulateFunction(DActLuDBiasCastTransposeHandler);
dict["te_dgated_act_lu_cast_transpose_ffi"] =
EncapsulateFunction(DGatedActLuCastTransposeHandler);
// Quantization // Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
......
...@@ -100,18 +100,18 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -100,18 +100,18 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
Result_Type input_cast_buf, Result_Type input_cast_trans_buf, Result_Type output_buf, Result_Type output_trans_buf,
Result_Type amax_out_buf, int64_t transpose_axis) { Result_Type amax_out_buf, int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data()); float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *input_cast = input_cast_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *input_cast_trans = input_cast_trans_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data()); float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive."); NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
...@@ -126,15 +126,15 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -126,15 +126,15 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto m = product(input_dims, 0, transpose_axis); auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size()); auto n = product(input_dims, transpose_axis, input_dims.size());
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m}; auto output_shape = input_shape;
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto input_cast_tensor = auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv); auto output_trans_tensor =
auto input_cast_trans_tensor = TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv);
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
stream); stream);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
...@@ -146,8 +146,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI, ...@@ -146,8 +146,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
.Arg<Buffer_Type>() // amax .Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv .Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // input_cast .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // input_cast_trans .Ret<Buffer_Type>() // output_trans
.Ret<Buffer_Type>() // amax_out .Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("transpose_axis"), .Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
...@@ -213,5 +213,70 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -213,5 +213,70 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), workspace.data(), stream); dbias_tensor.data(), workspace.data(), stream);
} }
// XLA FFI entry point fusing FP8 cast, transpose, and dbias (column-sum)
// reduction of the incoming gradient via nvte_cast_transpose_dbias.
//
// The input is flattened to (rows, cols) around `transpose_axis` (negative
// values count from the end). Results: the casted (rows, cols) gradient, its
// (cols, rows) transpose, the (cols,) dbias, the updated amax (which must
// alias the amax argument via buffer donation), and a scratch workspace.
Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                                 Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                 Result_Type output_buf, Result_Type output_trans_buf,
                                 Result_Type dbias_buf, Result_Type amax_out_buf,
                                 Result_Type workspace_buf, int64_t transpose_axis) {
  auto input_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto output_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  void *input_ptr = input_buf.untyped_data();
  void *output_ptr = output_buf->untyped_data();
  void *output_trans_ptr = output_trans_buf->untyped_data();
  void *dbias_ptr = dbias_buf->untyped_data();
  void *wkspace_ptr = workspace_buf->untyped_data();
  auto *amax_ptr = static_cast<float *>(amax_buf.untyped_data());
  auto *scale_ptr = static_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv_ptr = static_cast<float *>(scale_inv_buf.untyped_data());
  auto *amax_out_ptr = static_cast<float *>(amax_out_buf->untyped_data());

  // amax is updated in place by the kernel, so it must be donated to amax_out.
  NVTE_CHECK(amax_ptr == amax_out_ptr,
             "amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
  if (!use_fp8(output_dtype)) {
    // No scaling metadata for non-FP8 outputs.
    scale_ptr = nullptr;
    scale_inv_ptr = nullptr;
    amax_out_ptr = nullptr;
  }

  // Collapse the input to 2-D around the (normalized) transpose axis.
  auto dims = input_buf.dimensions();
  auto axis = (transpose_axis < 0) ? transpose_axis + static_cast<int64_t>(dims.size())
                                   : transpose_axis;
  auto rows = product(dims, 0, axis);
  auto cols = product(dims, axis, dims.size());
  auto cast_shape = std::vector<size_t>{rows, cols};
  auto trans_shape = std::vector<size_t>{cols, rows};
  auto bias_shape = std::vector<size_t>{cols};
  auto wkspace_dims = workspace_buf->dimensions();
  auto wkspace_shape = std::vector<size_t>(wkspace_dims.begin(), wkspace_dims.end());

  auto in_tensor = TensorWrapper(input_ptr, cast_shape, input_dtype);
  auto cast_tensor =
      TensorWrapper(output_ptr, cast_shape, output_dtype, amax_out_ptr, scale_ptr, scale_inv_ptr);
  auto trans_tensor = TensorWrapper(output_trans_ptr, trans_shape, output_dtype, amax_out_ptr,
                                    scale_ptr, scale_inv_ptr);
  auto bias_tensor = TensorWrapper(dbias_ptr, bias_shape, input_dtype);
  auto wkspace_tensor = TensorWrapper(wkspace_ptr, wkspace_shape, wkspace_dtype);

  nvte_cast_transpose_dbias(in_tensor.data(), cast_tensor.data(), trans_tensor.data(),
                            bias_tensor.data(), wkspace_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasCastTransposeHandler, DBiasCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("transpose_axis"),
                              FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -516,7 +516,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -516,7 +516,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv, dactivation_lu_scale_inv,
bwd_dtype, bwd_dtype,
static_axis_boundary=-1, static_axis_boundary=-1,
transpose_axis_boundary=-2,
activation_type=activation_type, activation_type=activation_type,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment