"vscode:/vscode.git/clone" did not exist on "48cf1e413c42b29909077afe21c7b9e57996a1cf"
Unverified commit aad4e173, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Generalizing Activation Primitives (#810)



* templated primitives and respective C++ functions
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes for LayerNormMLP, tests in test_custom_compute all passed
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added default arg for pybind get_workspace_size funcs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes for TestTransFormer with non-gated act tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* renamed gelu to act
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* improved enum implementation, avoid using magic numbers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Exposed C++ ActivationEnum to python side
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Changed error messages
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* changed conditional check on input shape for dbias_cast_transpose
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* changed dtype (tol) for bias grad tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes so that layer_norm_fp8_mlp can take bias = None
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Set bias = None in flax modules
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 2045a426
......@@ -194,8 +194,8 @@ class TestFP8Dot:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16)
b1 = None
b2 = None
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
......@@ -300,19 +300,19 @@ class TestFP8Dot:
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
if use_bias:
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.fixture(name="random_inputs")
......@@ -341,13 +341,14 @@ class TestActivationLu:
def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type = self.activation_type))
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear')])
def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
self.activation_type = activation_type
value_n_grad_primitive_func = jit(
......@@ -355,8 +356,6 @@ class TestActivationLu:
prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
""" prim_grad, = prim_grad """
""" ref_grad, = ref_grad """
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
......@@ -372,7 +371,7 @@ class TestActivationLuFP8(TestActivationLu):
activation_type = self.activation_type))
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
......@@ -384,6 +383,7 @@ class TestActivationLuFP8(TestActivationLu):
self.activation_type = activation_type
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,)))
......
......@@ -529,11 +529,12 @@ void cast_transpose_dbias(const Tensor &input,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
// TODO
// CheckInputTensor(input, "cast_transpose_dbias_input");
// CheckOutputTensor(*cast_output, "cast_output");
// CheckOutputTensor(*transposed_output, "transposed_output");
// CheckOutputTensor(*dbias, "dbias");
if (workspace->data.dptr != nullptr) {
CheckInputTensor(input, "cast_transpose_dbias_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
}
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
......
......@@ -4,7 +4,7 @@
"""JAX te custom call"""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
from typing import Tuple, Sequence, Union, Callable
from functools import partial, reduce
import operator
import os
......@@ -27,6 +27,7 @@ from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine_jax import NVTE_Activation_Enum
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
......@@ -124,6 +125,14 @@ def _check_valid_batch_dims(bdims):
f"but got {dim=}"
ActivationEnum = {
('gelu',): NVTE_Activation_Enum.GELU,
('gelu', 'linear'): NVTE_Activation_Enum.GEGLU,
('silu',): NVTE_Activation_Enum.SILU,
('silu', 'linear'): NVTE_Activation_Enum.SWIGLU
}
class BasePrimitive(metaclass=ABCMeta):
"""
jax primitive
......@@ -2556,244 +2565,28 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
is_training=is_training)
class GeluPrimitive(BasePrimitive):
class ActLuPrimitive(BasePrimitive):
"""
Gelu Froward Primitive
Activation Forward Primitive
"""
name = "te_gelu"
name = "te_act_lu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
impl_static_args = (1,)
@staticmethod
def abstract(x_aval):
def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument
"""
gated_gelu abstract
act_lu abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, x):
"""
gated_gelu lowering rules
"""
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape
out_types = [
ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-1])
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
in_dtype)
out = custom_caller(GeluPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(x):
assert GeluPrimitive.inner_primitive is not None
out = GeluPrimitive.inner_primitive.bind(x)
return out
@staticmethod
def batcher(batched_args, batch_dims):
"""
gated_gelu batcher
"""
_check_valid_batch_dims(batch_dims)
assert GeluPrimitive.outer_primitive is not None
inputs, = batched_args
inputs_bdim, = batch_dims
out_bdims = inputs_bdim
return GeluPrimitive.outer_primitive.bind(inputs), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
"""
gated_gelu infer_sharding_from_operands
"""
del result_infos # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
return out_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
"""
gated_gelu partitioning
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
impl = GeluPrimitive.impl
return mesh, impl, out_sharding, arg_shardings
register_primitive(GeluPrimitive)
def gelu(inputs: jnp.ndarray) -> jnp.ndarray:
"""
gelu wrapper
Return geglu(inputs)
Assume inputs has two dimensions shape and the memory layout is (N..., H)
"""
return GeluPrimitive.outer_primitive.bind(inputs)
class DGeluPrimitive(BasePrimitive):
"""
Dgated Gelu Primitive
"""
name = "te_dgelu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
@staticmethod
def abstract(dz_aval, x_aval):
"""
dgelu abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert dz_aval.shape == x_aval.shape
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, dz, x):
"""
dgelu lowering rules
"""
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
ir_in_type = ir.RankedTensorType(dz.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type)
gi_shape = gi_type.shape
assert ir_in_shape == gi_shape
ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
i_hidden_size = ir_in_shape[-1]
out_dtype = ir_in_type.element_type
out_shape = gi_shape
out_types = [
ir.RankedTensorType.get(out_shape, out_dtype),
]
operands = [dz, x]
operand_shapes = [ir_in_shape, gi_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
in_dtype, in_dtype)
out = custom_caller(DGeluPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(dz, x):
"""
dgelu implementation
"""
assert DGeluPrimitive.inner_primitive is not None
dx = DGeluPrimitive.inner_primitive.bind(dz, x)
return dx
@staticmethod
def batcher(batched_args, batch_dims):
"""
dgelu batcher
"""
_check_valid_batch_dims(batch_dims)
assert DGeluPrimitive.outer_primitive is not None
dz, x = batched_args
_, x_bdim = batch_dims
out_bdims = x_bdim
return DGeluPrimitive.outer_primitive.bind(dz, x), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
"""
dgelu infer_sharding_from_operands
"""
del result_infos # Unused.
gelu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
return dx_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
"""
dgelu partition
"""
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
impl = DGeluPrimitive.impl
return mesh, impl, out_shardings, arg_shardings
register_primitive(DGeluPrimitive)
def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
"""
dgelu fusion wrapper
Return dgeglu(inputs)
"""
return DGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)
class GatedGeluPrimitive(BasePrimitive):
"""
Gated Gelu Froward Primitive
"""
name = "te_gated_gelu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
@staticmethod
def abstract(x_aval):
"""
gated_gelu abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
x_shape = x_aval.shape
assert x_shape[-2] == 2 # Assume x in (....., 2, hidden)
assert (x_shape[-2] == 2 or x_shape[-2] == 1)
hidden_size = x_shape[-1]
batch_shapes = x_shape[:-2]
x_shape = x_aval.shape
out_aval = core.raise_to_shaped(x_aval)
out_shape = (batch_shapes) + (hidden_size,)
out_aval = out_aval.update(shape=out_shape, dtype=dtype)
......@@ -2801,9 +2594,9 @@ class GatedGeluPrimitive(BasePrimitive):
return out_aval
@staticmethod
def lowering(ctx, x):
def lowering(ctx, x, *, act_enum):
"""
gated_gelu lowering rules
act_lu lowering rules
"""
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -2821,100 +2614,101 @@ class GatedGeluPrimitive(BasePrimitive):
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-2])
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
in_dtype)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size), in_dtype, in_dtype, act_enum)
out = custom_caller(GatedGeluPrimitive.name, args, opaque, False)
out = custom_caller(ActLuPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(x):
assert GatedGeluPrimitive.inner_primitive is not None
out = GatedGeluPrimitive.inner_primitive.bind(x)
def impl(x, act_enum):
assert ActLuPrimitive.inner_primitive is not None
out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
return out
@staticmethod
def batcher(batched_args, batch_dims):
def batcher(batched_args, batch_dims, *, act_enum):
"""
gated_gelu batcher
act_lu batcher
"""
_check_valid_batch_dims(batch_dims)
assert GatedGeluPrimitive.outer_primitive is not None
assert ActLuPrimitive.outer_primitive is not None
inputs, = batched_args
inputs_bdim, = batch_dims
out_bdims = inputs_bdim
return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
"""
gated_gelu infer_sharding_from_operands
act_lu infer_sharding_from_operands
"""
del result_infos # Unused.
del result_infos, act_enum # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
return out_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
def partition(act_enum, mesh, arg_infos, result_infos):
"""
gated_gelu partitioning
act_lu partitioning
"""
del result_infos
del result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
impl = GatedGeluPrimitive.impl
impl = ActLuPrimitive.impl
return mesh, impl, out_sharding, arg_shardings
register_primitive(GatedGeluPrimitive)
register_primitive(ActLuPrimitive)
def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray:
def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
"""
gated gelu wrapper
Return FP8(geglu(inputs))
Assume inputs has two dimensions shape and the memory layout is (N, 2, H)
act_lu wrapper
Return act_lu(inputs)
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
return GatedGeluPrimitive.outer_primitive.bind(inputs)
act_type_id = ActivationEnum[activation_type]
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
class DgatedGeluPrimitive(BasePrimitive):
class DActLuPrimitive(BasePrimitive):
"""
Dgated Gelu Primitive
Dgated ActLu Primitive
"""
name = "te_dgated_gelu"
name = "te_dact_lu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
impl_static_args = (2,)
@staticmethod
def abstract(dz_aval, x_aval):
def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument
"""
dgated_gelu abstract
dact_lu abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
for axis in range(len(dz_aval.shape) - 1):
assert dz_aval.shape[axis] == x_aval.shape[axis]
assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden)
assert (x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1)
i_hidden_size = dz_aval.shape[-1]
g_hidden_size = x_aval.shape[-1]
assert i_hidden_size == g_hidden_size
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, dz, x):
def lowering(ctx, dz, x, *, act_enum):
"""
dgated_gelu lowering rules
dact_lu lowering rules
"""
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -2942,66 +2736,68 @@ class DgatedGeluPrimitive(BasePrimitive):
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
in_dtype, in_dtype)
in_dtype, in_dtype, act_enum)
out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False)
out = custom_caller(DActLuPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(dz, x):
def impl(dz, x, act_enum):
"""
dgated_gelu implementation
dact_lu implementation
"""
assert DgatedGeluPrimitive.inner_primitive is not None
dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x)
assert DActLuPrimitive.inner_primitive is not None
dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
return dx
@staticmethod
def batcher(batched_args, batch_dims):
def batcher(batched_args, batch_dims, *, act_enum):
"""
dgated_gelu batcher
dact_lu batcher
"""
_check_valid_batch_dims(batch_dims)
assert DgatedGeluPrimitive.outer_primitive is not None
assert DActLuPrimitive.outer_primitive is not None
dz, x = batched_args
_, x_bdim = batch_dims
out_bdims = x_bdim
return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims
return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
"""
dgated_gelu infer_sharding_from_operands
dact_lu infer_sharding_from_operands
"""
del result_infos # Unused.
gelu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
del result_infos, act_enum # Unused.
act_lu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
return dx_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
def partition(act_enum, mesh, arg_infos, result_infos):
"""
dgated_gelu partition
dact_lu partition
"""
del result_infos
del result_infos, act_enum
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
impl = DgatedGeluPrimitive.impl
impl = DActLuPrimitive.impl
return mesh, impl, out_shardings, arg_shardings
register_primitive(DgatedGeluPrimitive)
register_primitive(DActLuPrimitive)
def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
def dact_lu(inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
"""
dgated_gelu fusion wrapper
Return dgeglu(inputs)
dact_lu fusion wrapper
Return dgated_act_lu(inputs)
"""
return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)
act_type_id = ActivationEnum[activation_type]
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
def _normalize_axis_boundary(axis, ndim):
......@@ -3958,20 +3754,21 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale
epsilon=epsilon)
class GeluFp8Primitive(BasePrimitive):
class ActLuFp8Primitive(BasePrimitive):
"""
Gelu FP8 Primitive
ActLu FP8 Primitive
"""
name = "te_gelu_fp8"
name = "te_act_lu_fp8"
multiple_results = True
impl_static_args = (4,) #out_dtype
impl_static_args = (4, 5) #out_dtype, act_enum
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
act_enum): # pylint: disable=unused-argument
"""
te_gelu_p abstract
te_act_lu_p abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
# Currently only support casting to E4M3 only in C side.
......@@ -3981,15 +3778,19 @@ class GeluFp8Primitive(BasePrimitive):
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
assert (x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2)
hidden_size = x_aval.shape[-1]
batch_shape = x_aval.shape[:-2]
out_shape = (batch_shape) + (hidden_size,)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out_aval, updated_amax_aval
@staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
"""
te_gated_gelu_p lowering rules
te_gated_act_lu_p lowering rules
"""
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -4006,8 +3807,9 @@ class GeluFp8Primitive(BasePrimitive):
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-1])
out_shape = ir_x_shape
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
......@@ -4016,11 +3818,13 @@ class GeluFp8Primitive(BasePrimitive):
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))
opaque = transformer_engine_jax.pack_common_descriptor((
batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum)
out = custom_caller(GeluFp8Primitive.name,
out = custom_caller(ActLuFp8Primitive.name,
args,
opaque,
False,
......@@ -4029,55 +3833,58 @@ class GeluFp8Primitive(BasePrimitive):
return out
@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype):
def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
"""
to describe implementation
"""
assert GeluFp8Primitive.inner_primitive is not None
out, updated_amax = GeluFp8Primitive.inner_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
assert ActLuFp8Primitive.inner_primitive is not None
out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
act_enum=act_enum)
return out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype):
def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
"""
to describe batch rules for vmap
"""
_check_valid_batch_dims(batch_dims)
assert GeluFp8Primitive.outer_primitive is not None
assert ActLuFp8Primitive.outer_primitive is not None
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, amax_bdim
return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype), out_bdims
return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype,
act_enum=act_enum), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)
def sharded_impl(x, amax, scale, scale_inv):
local_x, local_amax = GeluFp8Primitive.impl(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
local_x, local_amax = ActLuFp8Primitive.impl(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
act_enum=act_enum)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_x, global_updated_amax
......@@ -4085,34 +3892,40 @@ class GeluFp8Primitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(GeluFp8Primitive)
register_primitive(ActLuFp8Primitive)
def gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
def act_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
gated gelu wrapper
Return FP8(geglu(x))
act wrapper
Return FP8(act_lu(x))
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
act_type_id = ActivationEnum[activation_type]
return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype,
act_enum = act_type_id)
class DGeluDBiasCastTransposePrimitive(BasePrimitive):
class DActLuDBiasCastTransposePrimitive(BasePrimitive):
"""
DGelu DBias Cast Transpose Primitive
DActLu DBias Cast Transpose Primitive
"""
name = "te_dgelu_dbias_cast_transpose"
name = "te_dact_lu_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary
impl_static_args = (5, 6, 7)
# out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum
impl_static_args = (5, 6, 7, 8)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary, transpose_axis_boundary):
static_axis_boundary, transpose_axis_boundary,
act_enum): # pylint: disable=unused-argument
"""
te_dgelu_dbais_cast_transpose_p abstract
te_dact_lu_dbais_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -4123,7 +3936,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size
t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
t_shape = _multidim_transpose(x_aval.shape,
static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
......@@ -4146,18 +3960,18 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
@staticmethod
def outer_abstract(*args, **kwargs):
"""
te_dgelu_dbais_cast_transpose_p outer abstract
te_dact_lu_dbais_cast_transpose_p outer abstract
"""
out, t_out, dbias, updated_amax_aval, _ = \
DGeluDBiasCastTransposePrimitive.abstract(*args, **kwargs)
DActLuDBiasCastTransposePrimitive.abstract(*args, **kwargs)
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
transpose_axis_boundary):
transpose_axis_boundary, act_enum):
"""
te_dgated_gelu_cast_transpose_p lowering rules
te_dgated_act_lu_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -4169,11 +3983,11 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
assert ir_dz_shape == x_shape
batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
ir_hidden_szie = ir_dz_shape[-1]
contracted_x_shape = (batch_szie, ir_hidden_szie)
contracted_x_shape = (x_batch_size, ir_hidden_szie)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
......@@ -4199,9 +4013,10 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
act_enum)
out = custom_caller(DGeluDBiasCastTransposePrimitive.name,
out = custom_caller(DActLuDBiasCastTransposePrimitive.name,
args,
opaque,
False,
......@@ -4211,12 +4026,12 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
@staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
transpose_axis_boundary):
transpose_axis_boundary, act_enum):
"""
to describe implementation
"""
assert DGeluDBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DGeluDBiasCastTransposePrimitive.inner_primitive.bind(
assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
......@@ -4224,18 +4039,19 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
transpose_axis_boundary):
transpose_axis_boundary, act_enum):
"""
to describe batch rules for vmap
"""
del static_axis_boundary
_check_valid_batch_dims(batch_dims)
assert DGeluDBiasCastTransposePrimitive.outer_primitive is not None
assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
......@@ -4244,7 +4060,7 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return DGeluDBiasCastTransposePrimitive.outer_primitive.bind(
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
......@@ -4252,12 +4068,13 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
transpose_axis_boundary=transpose_axis_boundary), out_bdims
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
arg_infos, result_infos):
del out_dtype, result_infos
def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary,
act_enum, mesh, arg_infos, result_infos):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
......@@ -4268,8 +4085,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
@staticmethod
def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
result_infos):
def partition(out_dtype, static_axis_boundary, transpose_axis_boundary,
act_enum, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
......@@ -4285,7 +4102,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = DGeluDBiasCastTransposePrimitive.impl(
local_out, local_t_out, local_dbias, local_amax =\
DActLuDBiasCastTransposePrimitive.impl(
dz,
x,
amax,
......@@ -4293,7 +4111,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_dbias, global_updated_amax
......@@ -4301,26 +4120,30 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DGeluDBiasCastTransposePrimitive)
register_primitive(DActLuDBiasCastTransposePrimitive)
def dgelu_dbias_cast_transpose(
dz: jnp.ndarray,
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
def dact_lu_dbias_cast_transpose(
dz: jnp.ndarray,
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
transpose_axis_boundary: int = -1,
activation_type: Sequence[Union[str, Callable]] = ('gelu',)
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose dgelu and dbias fusion wrapper
Return FP8(dgeglu(inputs)), dbias
cast transpose dact_lu and dbias fusion wrapper
Return FP8(dact_lu(inputs)), dbias
ONLY support non-gated activation type
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
return DGeluDBiasCastTransposePrimitive.outer_primitive.bind(
act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
......@@ -4328,7 +4151,8 @@ def dgelu_dbias_cast_transpose(
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_type_id)
class DBiasCastTransposePrimitive(BasePrimitive):
......@@ -4353,13 +4177,11 @@ class DBiasCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
gi_hidden_size = dz_aval.shape[-1]
gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
if dz_aval.shape[-2] == 2:
gi_hidden_size *= 2
dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
......@@ -4398,13 +4220,9 @@ class DBiasCastTransposePrimitive(BasePrimitive):
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
ir_hidden_szie = ir_dz_shape[-1]
if dz_aval.shape[-2] == 2:
batch_szie = reduce(operator.mul, ir_dz_shape[:-2])
ir_hidden_szie *= 2
else:
batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
contracted_dz_shape = (batch_szie, ir_hidden_szie)
batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
contracted_dz_shape = (batch_size, ir_hidden_size)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
......@@ -4413,7 +4231,7 @@ class DBiasCastTransposePrimitive(BasePrimitive):
ir_scale_inv_shape = ir_amax_shape
transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary,
transpose_axis_boundary)
dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_szie)
dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size)
wkspace_aval = ctx.avals_out[-1]
......@@ -4556,1356 +4374,60 @@ def dbias_cast_transpose(
transpose_axis_boundary=transpose_axis_boundary)
class GatedGeluFp8Primitive(BasePrimitive):
class DgatedActLuCastTransposePrimitive(BasePrimitive):
"""
Gated Gelu FP8 Primitive
Dgated ActLu Cast Transpose Primitive
"""
name = "te_gated_gelu_fp8"
name = "te_dgated_act_lu_cast_transpose"
multiple_results = True
impl_static_args = (4,) #out_dtype
impl_static_args = (5, 6, 7) # out_dtype, static_axis_boundary, act_enum
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary, act_enum): # pylint: disable=unused-argument
"""
te_gated_gelu_p abstract
te_dgated_act_lu_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
# Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert x_aval.shape[-2] == 2 # Linear + GeLU
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden)
hidden_size = x_aval.shape[-1]
batch_shape = x_aval.shape[:-2]
out_shape = (batch_shape) + (hidden_size,)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size
t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out_aval, updated_amax_aval
return out, t_out, updated_amax_aval
@staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
"""
te_gated_gelu_p lowering rules
te_dgated_act_lu_cast_transpose_p lowering rules
"""
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))
out = custom_caller(GatedGeluFp8Primitive.name,
args,
opaque,
False,
operand_output_aliases={1: 1})
return out
@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype):
"""
to describe implementation
"""
assert GatedGeluFp8Primitive.inner_primitive is not None
out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
return out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype):
"""
to describe batch rules for vmap
"""
_check_valid_batch_dims(batch_dims)
assert GatedGeluFp8Primitive.outer_primitive is not None
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, amax_bdim
return GatedGeluFp8Primitive.outer_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)
def sharded_impl(x, amax, scale, scale_inv):
local_x, local_amax = GatedGeluFp8Primitive.impl(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_x, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(GatedGeluFp8Primitive)
def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Gated-GeLU forward with fused FP8 cast.

    Return FP8(geglu(x)) and the updated amax, via GatedGeluFp8Primitive.
    The amax operand is aliased to the amax output by the underlying custom
    call, so the returned amax reflects the cast just performed.

    NOTE(review): the annotated return type lists three arrays, but the
    primitive's abstract rule produces two outputs (out, updated_amax) —
    confirm against callers.
    """
    return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)
class DgatedGeluCastTransposePrimitive(BasePrimitive):
    """
    Dgated Gelu Cast Transpose Primitive.

    Fused backward of gated GeLU: given the incoming gradient ``dz`` and the
    forward input ``x`` (laid out as (..., 2, hidden) — Linear + GeLU halves),
    emit the FP8-casted gradient, its multidimensional transpose, and the
    updated amax in a single custom call.
    """
    name = "te_dgated_gelu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, static_axis_boundary
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p abstract

        Shapes: out matches x; t_out is x transposed around axis -2 keeping
        the leading static axes; amax is passed through (updated in place).
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        # dz and x must agree on the hidden (innermost) dimension.
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p lowering rules

        Flattens batch dims, packs a common descriptor, and aliases operand 2
        (amax) to output 2 so amax is updated in place.
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        # dz has no gate axis, x carries the extra (..., 2, hidden) axis, so
        # their flattened batch sizes must match.
        dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_szie == x_batch_size
        assert x_shape[-2] == 2    # Linear + GeLU
        ir_hidden_szie = ir_dz_shape[-1]
        gi_hidden_size = x_shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(dz_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))
        out = custom_caller(DgatedGeluCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})
        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
        """
        to describe implementation
        """
        assert DgatedGeluCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary)
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
        """
        to describe batch rules for vmap

        The vmapped batch axis becomes the static (untransposed) leading axis.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
            dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
            static_axis_boundary=x_bdim), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
                                     result_infos):
        # Outputs follow x's sharding; the transposed output permutes the
        # spec the same way the data is transposed.
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, tranposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            # Per-shard compute, then reduce amax globally (max over all axes
            # except pipeline-parallel).
            local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DgatedGeluCastTransposePrimitive)
def dgated_gelu_cast_transpose(
        dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
        scale_inv: jnp.ndarray, out_dtype: TEDType,
        static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused cast-transpose of the gated-GeLU gradient.

    Return (FP8(dgeglu(dz, x)), its transpose, updated amax) via
    DgatedGeluCastTransposePrimitive. ``x`` is the forward input laid out as
    (..., 2, hidden); ``static_axis_boundary`` marks leading axes excluded
    from the transpose.
    """
    return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary)
# Primitives for SwiGLU and SiLU
class SiluPrimitive(BasePrimitive):
    """
    SiLU Forward Primitive (non-gated).
    """
    name = "te_silu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        silu abstract: output has the same shape and dtype as the input.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        silu lowering rules: element-wise custom call over a (batch, hidden)
        flattening of the input.
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-1])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(SiluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        assert SiluPrimitive.inner_primitive is not None
        out = SiluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        silu batcher: element-wise, so the batch dim passes straight through.
        """
        _check_valid_batch_dims(batch_dims)
        assert SiluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return SiluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        silu infer_sharding_from_operands: output inherits input sharding.
        """
        del result_infos  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        silu partitioning: embarrassingly parallel, no collectives needed.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        impl = SiluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings


register_primitive(SiluPrimitive)
def silu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    silu wrapper
    Return silu(inputs), same shape and dtype as the input.
    Assume inputs memory layout is (N..., H).
    """
    return SiluPrimitive.outer_primitive.bind(inputs)
class DSiluPrimitive(BasePrimitive):
    """
    DSiLU (SiLU backward) Primitive — non-gated.
    Computes the gradient of SiLU given the incoming gradient dz and the
    forward input x.
    """
    name = "te_dsilu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dsilu abstract: dz and x must match in shape and dtype; the output
        gradient has the same shape/dtype as x.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert dz_aval.shape == x_aval.shape

        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        dsilu lowering rules: element-wise custom call over the flattened
        (batch, hidden) view of the operands.
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        assert ir_in_shape == gi_shape

        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DSiluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        dsilu implementation
        """
        assert DSiluPrimitive.inner_primitive is not None
        dx = DSiluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dsilu batcher: element-wise, batch dim passes through from x.
        """
        _check_valid_batch_dims(batch_dims)
        assert DSiluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DSiluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dsilu infer_sharding_from_operands: gradient follows x's sharding.
        """
        del result_infos  # Unused.
        silu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dsilu partition: embarrassingly parallel, no collectives needed.
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DSiluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings


register_primitive(DSiluPrimitive)
def dsilu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    dsilu fusion wrapper
    Return dsilu(inputs), the SiLU gradient w.r.t. ``silu_inputs`` (the
    forward input), scaled by the incoming gradient ``inputs``.
    """
    return DSiluPrimitive.outer_primitive.bind(inputs, silu_inputs)
class GatedSiluPrimitive(BasePrimitive):
    """
    Gated SiLU (SwiGLU) Forward Primitive.
    Input is laid out as (..., 2, hidden); output drops the gate axis.
    """
    name = "te_gated_silu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        gated_silu abstract: (..., 2, H) in, (..., H) out, dtype preserved.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        x_shape = x_aval.shape
        assert x_shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        # NOTE(review): this reassignment duplicates the line above; harmless.
        x_shape = x_aval.shape
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = (batch_shapes) + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        gated_silu lowering rules: custom call over the flattened
        (batch, hidden) view; the gate axis is folded out of the output.
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-2])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(GatedSiluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        assert GatedSiluPrimitive.inner_primitive is not None
        out = GatedSiluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        gated_silu batcher: batch dim passes straight through.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedSiluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return GatedSiluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        gated_silu infer_sharding_from_operands: output keeps x's sharding
        minus the (dropped) gate axis.
        """
        del result_infos  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        gated_silu partitioning: embarrassingly parallel along batch axes.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        impl = GatedSiluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings


register_primitive(GatedSiluPrimitive)
def gated_silu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    gated silu wrapper
    Return gated_silu(inputs) in the input dtype (no FP8 cast here).
    Assume the memory layout of inputs is (N..., 2, H); the output drops the
    gate axis and has shape (N..., H).
    """
    return GatedSiluPrimitive.outer_primitive.bind(inputs)
class DgatedSiluPrimitive(BasePrimitive):
    """
    Dgated SiLU (SwiGLU backward) Primitive.
    Given the incoming gradient dz (..., H) and the forward input x
    (..., 2, H), produce the gradient w.r.t. x with shape (..., 2, H).
    """
    name = "te_dgated_silu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dgated_silu abstract: leading axes of dz and x must match; x carries
        the extra gate axis; hidden sizes must agree; output matches x.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)

        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        dgated_silu lowering rules: custom call over the flattened
        (batch, hidden) view of dz; output matches x's shape.
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DgatedSiluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        dgated_silu implementation
        """
        assert DgatedSiluPrimitive.inner_primitive is not None
        dx = DgatedSiluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dgated_silu batcher: batch dim passes through from x.
        """
        _check_valid_batch_dims(batch_dims)
        assert DgatedSiluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DgatedSiluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dgated_silu infer_sharding_from_operands: gradient follows x.
        """
        del result_infos  # Unused.
        silu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dgated_silu partition: embarrassingly parallel, no collectives.
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DgatedSiluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings


register_primitive(DgatedSiluPrimitive)
def dgated_silu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    dgated_silu fusion wrapper
    Return dgated_silu(inputs): the gated-SiLU gradient w.r.t. ``silu_inputs``
    (the (..., 2, H) forward input), scaled by the incoming gradient
    ``inputs``.
    """
    return DgatedSiluPrimitive.outer_primitive.bind(inputs, silu_inputs)
class SiluFp8Primitive(BasePrimitive):
    """
    SiLU FP8 Primitive: fused SiLU forward + FP8 cast with amax update.
    """
    name = "te_silu_fp8"
    multiple_results = True
    impl_static_args = (4,)    #out_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_silu_p abstract: output shape matches x, dtype is the FP8
        out_dtype; amax is passed through (updated in place).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_silu_fp8_p lowering rules: custom call over the flattened
        (batch, hidden) view; operand 1 (amax) is aliased to output 1.
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-1])
        out_shape = ir_x_shape
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        out = custom_caller(SiluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        to describe implementation
        """
        assert SiluFp8Primitive.inner_primitive is not None
        out, updated_amax = SiluFp8Primitive.inner_primitive.bind(x,
                                                                  amax,
                                                                  scale,
                                                                  scale_inv,
                                                                  out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert SiluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
                                                     out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        # Output follows x's sharding; amax follows its own operand spec.
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            # Per-shard compute, then reduce amax globally (max over all axes
            # except pipeline-parallel).
            local_x, local_amax = SiluFp8Primitive.impl(x,
                                                        amax,
                                                        scale,
                                                        scale_inv,
                                                        out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(SiluFp8Primitive)
def silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
             out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    SiLU + FP8-cast fusion wrapper.
    Return FP8(silu(x)) and the updated amax.
    """
    return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
class DSiluDBiasCastTransposePrimitive(BasePrimitive):
    """
    DSilu DBias Cast Transpose Primitive

    Fuses the backward of SiLU (dSiLU), the bias-gradient reduction (dbias),
    an FP8 cast, and a transpose into a single custom call.
    """
    name = "te_dsilu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (5, 6, 7)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary):
        """
        te_dsilu_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # dbias keeps the leading static axes and the hidden axis; everything
        # in between is reduced away.
        dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        # Ask the C++ side how much scratch the fused kernel needs so JAX can
        # allocate it as an extra (hidden) output.
        wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
            x_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dsilu_dbias_cast_transpose_p outer abstract

        Same as `abstract`, but hides the workspace buffer from callers.
        """
        out, t_out, dbias, updated_amax_aval, _ = \
            DSiluDBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_dsilu_dbias_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        assert ir_dz_shape == x_shape
        # Collapse all leading axes into one batch dimension for the kernel.
        batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
        ir_hidden_szie = ir_dz_shape[-1]
        contracted_x_shape = (batch_szie, ir_hidden_szie)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_szie)
        wkspace_aval = ctx.avals_out[-1]
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
        out = custom_caller(DSiluDBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            # amax operand (index 2) aliases the amax output (index 3)
                            operand_output_aliases={2: 3})
        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary):
        """
        Bind the inner primitive and drop the trailing workspace output.
        """
        assert DSiluDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DSiluDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        Batch (vmap) rule: the batch axis becomes the static leading axis.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DSiluDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return DSiluDBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        """
        Output shardings follow the x operand; the transposed output follows
        the transposed x spec and dbias keeps the static axes + hidden axis.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_shaprding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        """
        Partition rule: run the fused kernel per shard, then sum dbias over
        DP/FSDP axes and max-reduce amax over all axes except PP.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_shaprding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding,
                         amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DSiluDBiasCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            # Each shard holds a partial bias gradient; sum across DP/FSDP.
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Wire the abstract/lowering/batching/partitioning rules above into JAX.
register_primitive(DSiluDBiasCastTransposePrimitive)
def dsilu_dbias_cast_transpose(
        dz: jnp.ndarray,
        x: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
        out_dtype: TEDType,
        static_axis_boundary: int,
        transpose_axis_boundary: int = -1
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused dSiLU + dbias + cast + transpose wrapper.
    Return FP8(dsilu(dz, x)), its transpose, dbias, and the updated amax.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    return DSiluDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
class GatedSiluFp8Primitive(BasePrimitive):
    """
    Gated Silu FP8 Primitive

    Applies gated SiLU (SwiGLU) to an input shaped (..., 2, hidden) and
    casts the result to FP8 in one fused custom call.
    """
    name = "te_gated_silu_fp8"
    multiple_results = True
    impl_static_args = (4,)    #out_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_gated_silu_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        # Gating contracts the pair axis: output is (..., hidden).
        out_shape = (batch_shape) + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_gated_silu_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))
        out = custom_caller(GatedSiluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            # amax operand (index 1) aliases the amax output (index 1)
                            operand_output_aliases={1: 1})
        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        Bind the inner primitive; return (FP8 output, updated amax).
        """
        assert GatedSiluFp8Primitive.inner_primitive is not None
        out, updated_amax = GatedSiluFp8Primitive.inner_primitive.bind(x,
                                                                       amax,
                                                                       scale,
                                                                       scale_inv,
                                                                       out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        Batch (vmap) rule: rebind the outer primitive on batched operands.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedSiluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, amax_bdim
        return GatedSiluFp8Primitive.outer_primitive.bind(x,
                                                          amax,
                                                          scale,
                                                          scale_inv,
                                                          out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        """
        Output follows the x spec with the contracted pair axis removed.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        """
        Partition rule: run per shard, then max-reduce amax over all mesh
        axes except pipeline parallelism.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GatedSiluFp8Primitive.impl(x,
                                                             amax,
                                                             scale,
                                                             scale_inv,
                                                             out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Wire the abstract/lowering/batching/partitioning rules above into JAX.
register_primitive(GatedSiluFp8Primitive)
def gated_silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Gated SiLU (SwiGLU) + FP8-cast wrapper.
    x is expected in shape (..., 2, hidden): linear half and gating half.
    Return FP8(gated_silu(x)) and the updated amax.
    """
    return GatedSiluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)
class DgatedSiluCastTransposePrimitive(BasePrimitive):
    """
    Dgated Silu Cast Transpose Primitive

    Fuses the backward of gated SiLU (SwiGLU) with an FP8 cast and transpose.
    """
    name = "te_dgated_silu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, static_axis_boundary
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary):
        """
        te_dgated_silu_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + SiLU halves of the gated input
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        # Transpose around axis -2 so the (2, hidden) pair stays together.
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
"""
te_dgated_silu_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_szie == gi_hidden_size
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_szie == gi_hidden_size
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
......@@ -5922,11 +4444,13 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum)
out = custom_caller(DgatedSiluCastTransposePrimitive.name,
out = custom_caller(DgatedActLuCastTransposePrimitive.name,
args,
opaque,
False,
......@@ -5935,41 +4459,43 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
return out
@staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
"""
to describe implementation
"""
assert DgatedSiluCastTransposePrimitive.inner_primitive is not None
out, t_out, updated_amax = DgatedSiluCastTransposePrimitive.inner_primitive.bind(
assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary)
static_axis_boundary=static_axis_boundary,
act_enum=act_enum)
return out, t_out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
"""
to describe batch rules for vmap
"""
del static_axis_boundary
_check_valid_batch_dims(batch_dims)
assert DgatedSiluCastTransposePrimitive.outer_primitive is not None
assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim
return DgatedSiluCastTransposePrimitive.outer_primitive.bind(
return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
static_axis_boundary=x_bdim), out_bdims
static_axis_boundary=x_bdim,
act_enum=act_enum), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
result_infos):
del out_dtype, result_infos
def infer_sharding_from_operands(out_dtype, static_axis_boundary, act_enum,
mesh, arg_infos, result_infos):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
......@@ -5978,7 +4504,8 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
return (out_sharding, tranposed_out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
def partition(out_dtype, static_axis_boundary, act_enum,
mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
......@@ -5990,36 +4517,41 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_amax = DgatedSiluCastTransposePrimitive.impl(
local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary)
static_axis_boundary=static_axis_boundary,
act_enum=act_enum)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DgatedSiluCastTransposePrimitive)
register_primitive(DgatedActLuCastTransposePrimitive)
def dgated_silu_cast_transpose(
dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, out_dtype: TEDType,
static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
def dgated_act_lu_cast_transpose(
dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, out_dtype: TEDType,
static_axis_boundary: int,
activation_type: Sequence[Union[str, Callable]]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose d_gated_silu fusion wrapper
Return FP8(dgeglu(inputs))
cast transpose d_gated_act_lu fusion wrapper
Return FP8(dgated_act_lu(inputs))
"""
return DgatedSiluCastTransposePrimitive.outer_primitive.bind(
act_type_id = ActivationEnum[activation_type]
return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary)
static_axis_boundary=static_axis_boundary,
act_enum=act_type_id)
......@@ -25,25 +25,14 @@ pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gelu"] = EncapsulateFunction(Gelu);
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_act_lu"] = EncapsulateFunction(ActLu);
dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8);
dict["te_dact_lu"] = EncapsulateFunction(DActLu);
dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
// TODO
dict["te_silu"] = EncapsulateFunction(Silu);
dict["te_silu_fp8"] = EncapsulateFunction(SiluFP8);
dict["te_dsilu"] = EncapsulateFunction(DSilu);
dict["te_dsilu_dbias_cast_transpose"] = EncapsulateFunction(DSiluDBiasCastTranspose);
dict["te_gated_silu"] = EncapsulateFunction(GatedSilu);
dict["te_gated_silu_fp8"] = EncapsulateFunction(GatedSiluFP8);
dict["te_dgated_silu"] = EncapsulateFunction(DGatedSilu);
dict["te_dgated_silu_cast_transpose"] = EncapsulateFunction(DGatedSiluCastTranspose);
//
dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose);
dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose);
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
......@@ -67,8 +56,11 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor,
pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor,
pybind11::arg(), pybind11::arg(), pybind11::arg(),
pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
......@@ -109,6 +101,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
pybind11::enum_<NVTE_Activation_Enum>(m, "NVTE_Activation_Enum", pybind11::module_local())
.value("GELU", NVTE_Activation_Enum::GELU)
.value("GEGLU", NVTE_Activation_Enum::GEGLU)
.value("SILU", NVTE_Activation_Enum::SILU)
.value("SWIGLU", NVTE_Activation_Enum::SWIGLU);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
......
......@@ -37,6 +37,19 @@ std::vector<size_t> MakeShapeVector(NVTEShape shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
size_t get_activation_len(NVTE_Activation_Enum act_enum) {
switch (act_enum) {
case NVTE_Activation_Enum::GELU: return 1;
case NVTE_Activation_Enum::GEGLU: return 2;
case NVTE_Activation_Enum::SILU: return 1;
case NVTE_Activation_Enum::SWIGLU: return 2;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
return -1;
}
}
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
......@@ -52,23 +65,26 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
}
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype) {
DType out_dtype, size_t act_enum) {
CustomCallCommonDescriptor desc;
desc.shape.from_vector(shape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype) {
DType out_dtype, DType wk_dtype,
size_t act_enum) {
CustomCallCommonWkDescriptor desc;
desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.wk_dtype = wk_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc);
}
......@@ -170,31 +186,50 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
input_cast_trans_tensor.data(), stream);
}
void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n};
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output,
NVTE_Activation_Enum act_enum) {
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
auto input_tensor = TensorWrapper(input, input_shape,
static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape,
static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
switch (act_enum) {
case NVTE_Activation_Enum::GELU:
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::GEGLU:
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_swish(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SWIGLU:
nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
}
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
nullptr, nullptr, output, act_enum);
}
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
......@@ -211,107 +246,91 @@ void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opa
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
scale_inv, amax_out, output, act_enum);
}
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
auto *act_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n * act_len};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
switch (act_enum) {
case NVTE_Activation_Enum::GELU:
nvte_dgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::GEGLU:
nvte_dgeglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_dswish(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SWIGLU:
nvte_dswiglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
}
// HERE
// Query the workspace size required by the fused dActivation + dBias +
// cast + transpose kernels.
//
// Calling nvte_cast_transpose_dbias_dgelu with null data pointers does not
// launch any work: it only fills in the workspace shape/dtype, which we
// report back to the Python side as a (shape, dtype) pair.
//
// This span previously contained two merged function signatures and a stray
// nvte_cast_transpose_dbias query (diff residue); only the dGELU-based query
// belongs here.
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
  auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
  auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
  auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);

  TensorWrapper dummy_workspace;
  // For now, all dbias_dact(-s) have the same workspace size, so the dGELU
  // variant is representative of every supported activation.
  nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(),
                                  dbias_tensor.data(), dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *output_trans = buffers[5];
auto *dbias = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
void *workspace_ptr = buffers[8];
auto *act_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
......@@ -322,12 +341,15 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
......@@ -336,81 +358,27 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
workspace.data(), stream);
}
// Launch the gated-GELU (GeGLU) forward kernel.
//
// The gated input carries both the activation half and the gate half, hence
// shape {m, 2n}; the output keeps only the activated half, {m, n}.
// scale/scale_inverse/amax may be null when the output is not FP8.
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
                   cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
  auto input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n};
  // in_dtype/out_dtype are already DType; the original static_cast<DType>
  // wrappers were redundant no-ops and have been dropped.
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax, scale, scale_inverse);
  nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
}
// Non-FP8 gated-GELU custom call: unpack the descriptor and forward to
// GatedGeluImpl with all quantization metadata pointers disabled.
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto *in_buf = buffers[0];
  auto *out_buf = buffers[1];
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  GatedGeluImpl(in_buf, rows, cols, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr,
                nullptr, out_buf);
}
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
switch (act_enum) {
case NVTE_Activation_Enum::GELU:
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_cast_transpose_dbias_dswish(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
default:
throw std::runtime_error("Activation Type is not Implemented in DActLuDBiasCastTranspose");
break;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
// Backward of gated GELU: given the incoming gradient (buffers[0]) and the
// forward gated input (buffers[1]), write the input gradient to buffers[2].
// The gradient w.r.t. the gated input spans both halves, hence {m, 2n}.
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto *grad = buffers[0];
  auto *fwd_input = buffers[1];
  auto *dgrad = buffers[2];
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  auto grad_tensor = TensorWrapper(grad, std::vector<size_t>{rows, cols}, desc.in_dtype);
  auto fwd_input_tensor =
      TensorWrapper(fwd_input, std::vector<size_t>{rows, cols * 2}, desc.in_dtype);
  auto dgrad_tensor = TensorWrapper(dgrad, std::vector<size_t>{rows, cols * 2}, desc.out_dtype);
  nvte_dgeglu(grad_tensor.data(), fwd_input_tensor.data(), dgrad_tensor.data(), stream);
}
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
auto *act_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
......@@ -427,124 +395,69 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto input_shape = desc.shape.to_vector();
auto gelu_input_shape = std::vector<size_t>{m, n * 2};
auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_dgeglu_cast_transpose(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), stream);
}
// Launch the SiLU (swish) forward kernel on an {m, n} buffer.
// scale/scale_inverse/amax may be null when the output is not FP8.
void SiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
              cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  // in_dtype/out_dtype are already DType; the original static_cast<DType>
  // wrappers were redundant no-ops and have been dropped.
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax, scale, scale_inverse);
  nvte_swish(input_tensor.data(), output_tensor.data(), stream);
}
// Non-FP8 SiLU custom call: unpack the descriptor and forward to SiluImpl
// with all quantization metadata pointers disabled.
void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto *in_buf = buffers[0];
  auto *out_buf = buffers[1];
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  SiluImpl(in_buf, rows, cols, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
           out_buf);
}
void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
switch (act_enum) {
case NVTE_Activation_Enum::GEGLU:
nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
case NVTE_Activation_Enum::SWIGLU:
nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
// Backward of SiLU: given the incoming gradient (buffers[0]) and the forward
// input (buffers[1]), write the input gradient to buffers[2].  All three
// buffers share the same {m, n} logical shape.
void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto *grad = buffers[0];
  auto *fwd_input = buffers[1];
  auto *dgrad = buffers[2];
  const std::vector<size_t> shape{desc.shape.dims[0], desc.shape.dims[1]};
  auto grad_tensor = TensorWrapper(grad, shape, desc.in_dtype);
  auto fwd_input_tensor = TensorWrapper(fwd_input, shape, desc.in_dtype);
  auto dgrad_tensor = TensorWrapper(dgrad, shape, desc.out_dtype);
  nvte_dswish(grad_tensor.data(), fwd_input_tensor.data(), dgrad_tensor.data(), stream);
}
// Query the workspace size required by the fused cast + transpose + dBias
// kernel.  Calling nvte_cast_transpose_dbias with null data pointers only
// populates the workspace shape/dtype; no kernel is launched.
//
// This span previously contained two merged function signatures and stray
// dact-input tensors plus an nvte_cast_transpose_dbias_dgelu query (diff
// residue); the plain dbias query is the one that belongs here.
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
  auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
  auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);

  TensorWrapper dummy_workspace;
  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
                            output_trans_tensor.data(), dbias_tensor.data(),
                            dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *output_trans = buffers[5];
auto *dbias = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
......@@ -556,13 +469,11 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
......@@ -571,111 +482,9 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dswish(input_tensor.data(), silu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
}
// Launch the gated-SiLU (SwiGLU) forward kernel.
//
// The gated input carries both the activation half and the gate half, hence
// shape {m, 2n}; the output keeps only the activated half, {m, n}.
// scale/scale_inverse/amax may be null when the output is not FP8.
void GatedSiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
                   cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
  auto input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n};
  // in_dtype/out_dtype are already DType; the original static_cast<DType>
  // wrappers were redundant no-ops and have been dropped.
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax, scale, scale_inverse);
  nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
}
// Non-FP8 gated-SiLU custom call: unpack the descriptor and forward to
// GatedSiluImpl with all quantization metadata pointers disabled.
void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto *in_buf = buffers[0];
  auto *out_buf = buffers[1];
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  GatedSiluImpl(in_buf, rows, cols, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr,
                nullptr, out_buf);
}
// FP8 gated-SiLU custom call.
// Buffers: [0]=input, [1]=amax in, [2]=scale, [3]=scale_inv, [4]=output,
// [5]=amax out (must alias [1], enforced by the assert).
void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *in_buf = buffers[0];
  auto *amax_in = reinterpret_cast<float *>(buffers[1]);
  auto *scale_ptr = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv_ptr = reinterpret_cast<float *>(buffers[3]);
  auto *out_buf = buffers[4];
  auto *amax_out_ptr = reinterpret_cast<float *>(buffers[5]);
  assert(amax_in == amax_out_ptr);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: disable quantization metadata entirely.
    scale_ptr = nullptr;
    scale_inv_ptr = nullptr;
    amax_out_ptr = nullptr;
  }
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  GatedSiluImpl(in_buf, rows, cols, desc.in_dtype, desc.out_dtype, scale_ptr, stream,
                scale_inv_ptr, amax_out_ptr, out_buf);
}
// Backward of gated SiLU: given the incoming gradient (buffers[0]) and the
// forward gated input (buffers[1]), write the input gradient to buffers[2].
// The gradient w.r.t. the gated input spans both halves, hence {m, 2n}.
void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto *grad = buffers[0];
  auto *fwd_input = buffers[1];
  auto *dgrad = buffers[2];
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  auto grad_tensor = TensorWrapper(grad, std::vector<size_t>{rows, cols}, desc.in_dtype);
  auto fwd_input_tensor =
      TensorWrapper(fwd_input, std::vector<size_t>{rows, cols * 2}, desc.in_dtype);
  auto dgrad_tensor = TensorWrapper(dgrad, std::vector<size_t>{rows, cols * 2}, desc.out_dtype);
  nvte_dswiglu(grad_tensor.data(), fwd_input_tensor.data(), dgrad_tensor.data(), stream);
}
// Backward of gated SiLU fused with FP8 cast and transpose.
// Buffers: [0]=incoming grad, [1]=forward gated input, [2]=amax, [3]=scale,
// [4]=scale_inv, [5]=casted grad, [6]=casted+transposed grad,
// [7]=amax out (must alias [2], enforced by the assert).
//
// Fix: removed a stray trailing nvte_cast_transpose_dbias call that
// referenced dbias_tensor/workspace — neither exists in this scope (merge
// residue from DBiasCastTranspose); it could not have compiled.
void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len) {
  auto *input = buffers[0];
  auto *silu_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 output: disable quantization metadata entirely.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto input_shape = desc.shape.to_vector();
  // Forward input and the produced gradients carry both gate halves: {m, 2n}.
  auto silu_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  nvte_dswiglu_cast_transpose(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(),
                              output_trans_tensor.data(), stream);
}
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
......@@ -43,14 +43,24 @@ struct Shape {
}
};
// Activation kinds dispatched by the generalized activation custom calls.
// GEGLU and SWIGLU are the gated variants of GELU and SILU respectively.
// The numeric values are exposed to the Python side (desc.act_enum), so the
// declaration order must stay stable.
enum class NVTE_Activation_Enum {
GELU,
GEGLU,
SILU,
SWIGLU,
};
size_t get_activation_len(NVTE_Activation_Enum act_enum);
// Opaque descriptor packed on the Python side and shipped through XLA custom
// calls for the workspace-free activation/cast kernels.
struct CustomCallCommonDescriptor {
Shape shape;  // logical 2D shape of the operand
DType in_dtype;  // dtype of the input buffer
DType out_dtype;  // dtype of the output buffer
size_t act_enum;  // NVTE_Activation_Enum value selecting the activation (0 by default)
};
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype);
DType out_dtype, size_t act_enum = 0);
struct CustomCallCommonWkDescriptor {
Shape shape;
......@@ -58,11 +68,13 @@ struct CustomCallCommonWkDescriptor {
DType in_dtype;
DType out_dtype;
DType wk_dtype;
size_t act_enum;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype);
const std::vector<size_t> &wkshape,
DType in_dtype, DType out_dtype, DType wk_dtype,
size_t act_enum = 0);
struct CustomCallNormDescriptor {
size_t batch_size;
......@@ -140,17 +152,16 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// TODO (Phuong): Templating these 9x2 rountines before adding ReGLU, QuickGeLU, Squared ReLu
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
......@@ -159,31 +170,7 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
......@@ -955,7 +955,6 @@ class LayerNormMLP(TransformerEngineBase):
normalize_acts = tuple(reversed(normalize_acts)
if normalize_acts[0] == 'linear' else normalize_acts)
is_gated = normalize_acts in gated_act_pool
is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
......@@ -1052,8 +1051,8 @@ class LayerNormMLP(TransformerEngineBase):
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
else:
bias_1 = jnp.empty(0, self.dtype)
bias_2 = jnp.empty(0, self.dtype)
bias_1 = None
bias_2 = None
out = fused_layernorm_fp8_mlp(y,
scale,
......@@ -1134,7 +1133,6 @@ class LayerNormMLP(TransformerEngineBase):
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_act_implemented:
z = activation_lu(x, normalize_acts)
......@@ -1144,8 +1142,8 @@ class LayerNormMLP(TransformerEngineBase):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = functools.reduce(operator.mul, activations)
if not is_gated:
z = jnp.reshape(z, (*z.shape[:-2], -1))
if num_activations == 1:
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_hidden_dropout_dims,
......
......@@ -11,14 +11,8 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import silu, silu_fp8
from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose
from .cpp_extensions import gated_silu, gated_silu_fp8
from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose
from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
......@@ -26,44 +20,6 @@ from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
# Dispatch table: activation spec (tuple of activation names) -> non-FP8
# forward ('fwd') and backward ('bwd') primitives.  ('x', 'linear') entries
# are the gated variants.
# NOTE(review): this table appears superseded by the generic act_lu/dact_lu
# primitives used elsewhere in this module — confirm before extending.
activation_dict = {
    ('gelu',): {
        'fwd': gelu,
        "bwd": dgelu
    },
    ('gelu', 'linear'): {
        'fwd': gated_gelu,
        'bwd': dgated_gelu
    },
    ('silu',): {
        'fwd': silu,
        "bwd": dsilu
    },
    ('silu', 'linear'): {
        'fwd': gated_silu,
        'bwd': dgated_silu
    }
}
# FP8 counterpart of activation_dict: the backward entries fuse the
# activation gradient with cast/transpose (and dbias for non-gated cases).
# NOTE(review): same supersession caveat as activation_dict above.
activation_fp8_dict = {
    ('gelu',): {
        'fwd': gelu_fp8,
        'bwd': dgelu_dbias_cast_transpose
    },
    ('gelu', 'linear'): {
        'fwd': gated_gelu_fp8,
        'bwd': dgated_gelu_cast_transpose
    },
    ('silu',): {
        'fwd': silu_fp8,
        'bwd': dsilu_dbias_cast_transpose
    },
    ('silu', 'linear'): {
        'fwd': gated_silu_fp8,
        'bwd': dgated_silu_cast_transpose
    }
}
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
"""
......@@ -84,7 +40,7 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
def _activation_lu_fwd_rule(x, activation_type):
    """Forward rule of the custom-VJP activation unit.

    Applies the generalized fused activation primitive and stashes the
    pre-activation input as the residual for the backward pass.
    """
    # Fix: the original body computed fwd_output twice — a dead first
    # assignment via the removed activation_dict dispatch table immediately
    # overwritten by act_lu (diff residue).  Only the act_lu call remains.
    fwd_output = act_lu(x, activation_type)
    return fwd_output, (x,)
......@@ -92,7 +48,7 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx
assert x.dtype == g.dtype
dx = activation_dict[activation_type]["bwd"](g, x)
dx = dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx,)
......@@ -106,7 +62,7 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, sca
"""
Activation Unit
"""
transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
transpose_indices = (1, 2, 0)
dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
......@@ -127,19 +83,15 @@ def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_us
return output
def _activation_lu_fp8_fwd_rule(
x,
dx_trans_no_use, # pylint: disable=unused-argument
dbias_no_use, # pylint: disable=unused-argument
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
activation_type):
activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv,
fwd_dtype)
def _activation_lu_fp8_fwd_rule(x,
dx_trans_no_use, # pylint: disable=unused-argument
dbias_no_use, # pylint: disable=unused-argument
amax,
scale, scale_inv,
fwd_dtype, bwd_dtype, # pylint: disable=unused-argument
activation_type):
activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype,
activation_type)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x, amax, scale, scale_inv)
return activation_lu_out, ctx
......@@ -153,14 +105,14 @@ def _activation_lu_fp8_bwd_rule(
g):
x, amax, scale, scale_inv = ctx
activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"]
if len(activation_type) > 1: #gated, no bias
if len(activation_type) > 1: #gated, no bias
dactivation_lu, dactivation_lu_trans, amax_out = \
activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1, activation_type)
dbias = jnp.empty(x.shape[-1], x.dtype)
else:
else: #not gated, with bias
dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype,
-1, -2, activation_type)
dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
......@@ -262,7 +214,6 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_type,
use_bias):
is_gated = len(activation_type) > 1
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
......@@ -276,15 +227,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
# Squeeze act axis
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
if not is_gated:
kernel_1 = jnp.squeeze(kernel_1, axis=-2)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
......@@ -337,8 +282,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias:
bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
dot_1_output += jnp.reshape(bias_1, bias_1_shape)
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
else:
bias_1_shape = None
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
......@@ -347,12 +295,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_out_scale = scale[gemm2_x_idx]
activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"]
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = \
activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype)
act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype, activation_type)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
casted_activation_lu_out, dot_2_input_axes)
......@@ -370,15 +317,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias:
bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
dot_2_output += jnp.reshape(bias_2, bias_2_shape)
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
else:
bias_2_shape = None
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32)
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32)
return dot_2_output, ctx
......@@ -403,8 +353,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
is_gated = len(activation_type) > 1
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
......@@ -413,7 +361,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
dbias_cast_transpose(grad, grad_amax, grad_scale,
......@@ -427,7 +374,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
grad_scale_inv, bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = jnp.empty(bias_2_shape, grad.dtype)
dbias_2 = None
casted_activation_lu_out_t = transpose(casted_activation_lu_out,
static_axis_boundary=-1,
......@@ -453,11 +400,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale = scale[gemm1_grad_idx]
dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]
dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"]
dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)
if is_gated:
if len(activation_type) > 1: # if gated
if use_bias:
dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dbias_cast_transpose(
dactivation_lu,
......@@ -470,19 +415,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose(
dgated_act_lu_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
static_axis_boundary=-1,
activation_type=activation_type)
dbias_1 = None
else:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose(
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
dact_lu_dbias_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
......@@ -490,9 +436,11 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
transpose_axis_boundary=-2,
activation_type=activation_type)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
cast_transpose(
dactivation_lu,
......@@ -501,28 +449,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
transpose_axis_boundary=-2)
dbias_1 = None
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
xt_batch_dims_2 = xt_batch_dims if not is_gated \
else tuple(i + 1 for i in xt_batch_dims)
xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1
if not is_gated:
wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
# (batch..., hidden_out) x (hidden_in, hidden_out)
if is_gated:
x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2))
else:
x_contracting_dims = (x_contracting_dims, (1,))
x_contracting_dims = ((min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims), (1,2))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
kernel_1_scale_inv, grad.dtype, x_contracting_dims,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment