"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "ce0b46c4d1d9d98104b7c6edce445b488f346279"
Unverified Commit aad4e173 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Generalizing Activation Primitives (#810)



* templated primitives and respective C++ functions
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fixes for LayerNormMLP, tests in test_custom_compute all passed
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* added default arg for pybind get_workspace_size funcs
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fixes for TestTransFormer with non-gated act tests
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* renamed gelu to act
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* improved enum implementation, avoid using magic numbers
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* Exposed C++ ActivationEnum to python side
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* Changed error messages
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* changed conditional check on input shape for dbias_cast_transpose
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* changed dtype (tol) for bias grad tests
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fixes so that layer_norm_fp8_mlp can take bias = None
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* Set bias = None in flax modules
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 2045a426
...@@ -194,8 +194,8 @@ class TestFP8Dot: ...@@ -194,8 +194,8 @@ class TestFP8Dot:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16) b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else: else:
b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16) b1 = None
b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16) b2 = None
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros( init_fp8_metas_amax = jnp.zeros(
...@@ -300,19 +300,19 @@ class TestFP8Dot: ...@@ -300,19 +300,19 @@ class TestFP8Dot:
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32), jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32), assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
if use_bias: if use_bias:
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32), jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16) dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
...@@ -341,13 +341,14 @@ class TestActivationLu: ...@@ -341,13 +341,14 @@ class TestActivationLu:
def primitive_func(self, inputs): def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type = self.activation_type)) return jnp.mean(activation_lu(inputs, activation_type = self.activation_type))
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',), @pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'), ('gelu', 'linear'),
('silu',), ('silu',),
('silu', 'linear')]) ('silu', 'linear')])
def test_activation_lu(self, random_inputs, activation_type): def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
self.activation_type = activation_type self.activation_type = activation_type
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
...@@ -355,8 +356,6 @@ class TestActivationLu: ...@@ -355,8 +356,6 @@ class TestActivationLu:
prim_out, (prim_grad,) = value_n_grad_primitive_func(x) prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type) ref_out, (ref_grad,) = self.ref_func(x, activation_type)
""" prim_grad, = prim_grad """
""" ref_grad, = ref_grad """
assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
...@@ -372,7 +371,7 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -372,7 +371,7 @@ class TestActivationLuFP8(TestActivationLu):
activation_type = self.activation_type)) activation_type = self.activation_type))
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',), @pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'), ('gelu', 'linear'),
('silu',), ('silu',),
...@@ -384,6 +383,7 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -384,6 +383,7 @@ class TestActivationLuFP8(TestActivationLu):
self.activation_type = activation_type self.activation_type = activation_type
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,))) value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,)))
......
...@@ -529,11 +529,12 @@ void cast_transpose_dbias(const Tensor &input, ...@@ -529,11 +529,12 @@ void cast_transpose_dbias(const Tensor &input,
Tensor *dbias, Tensor *dbias,
Tensor *workspace, Tensor *workspace,
cudaStream_t stream) { cudaStream_t stream) {
// TODO if (workspace->data.dptr != nullptr) {
// CheckInputTensor(input, "cast_transpose_dbias_input"); CheckInputTensor(input, "cast_transpose_dbias_input");
// CheckOutputTensor(*cast_output, "cast_output"); CheckOutputTensor(*cast_output, "cast_output");
// CheckOutputTensor(*transposed_output, "transposed_output"); CheckOutputTensor(*transposed_output, "transposed_output");
// CheckOutputTensor(*dbias, "dbias"); CheckOutputTensor(*dbias, "dbias");
}
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""JAX te custom call""" """JAX te custom call"""
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple, Sequence, Union, Callable
from functools import partial, reduce from functools import partial, reduce
import operator import operator
import os import os
...@@ -27,6 +27,7 @@ from transformer_engine_jax import NVTE_Bias_Type ...@@ -27,6 +27,7 @@ from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine_jax import NVTE_Activation_Enum
from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp from .sharding import all_reduce_sum_along_dp_fsdp
...@@ -124,6 +125,14 @@ def _check_valid_batch_dims(bdims): ...@@ -124,6 +125,14 @@ def _check_valid_batch_dims(bdims):
f"but got {dim=}" f"but got {dim=}"
ActivationEnum = {
('gelu',): NVTE_Activation_Enum.GELU,
('gelu', 'linear'): NVTE_Activation_Enum.GEGLU,
('silu',): NVTE_Activation_Enum.SILU,
('silu', 'linear'): NVTE_Activation_Enum.SWIGLU
}
class BasePrimitive(metaclass=ABCMeta): class BasePrimitive(metaclass=ABCMeta):
""" """
jax primitive jax primitive
...@@ -2556,244 +2565,28 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda ...@@ -2556,244 +2565,28 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
is_training=is_training) is_training=is_training)
class GeluPrimitive(BasePrimitive): class ActLuPrimitive(BasePrimitive):
""" """
Gelu Froward Primitive Activation Forward Primitive
""" """
name = "te_gelu" name = "te_act_lu"
multiple_results = False multiple_results = False
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
impl_static_args = () impl_static_args = (1,)
@staticmethod @staticmethod
def abstract(x_aval): def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument
""" """
gated_gelu abstract act_lu abstract
""" """
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, x):
"""
gated_gelu lowering rules
"""
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape
out_types = [
ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-1])
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
in_dtype)
out = custom_caller(GeluPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(x):
assert GeluPrimitive.inner_primitive is not None
out = GeluPrimitive.inner_primitive.bind(x)
return out
@staticmethod
def batcher(batched_args, batch_dims):
"""
gated_gelu batcher
"""
_check_valid_batch_dims(batch_dims)
assert GeluPrimitive.outer_primitive is not None
inputs, = batched_args
inputs_bdim, = batch_dims
out_bdims = inputs_bdim
return GeluPrimitive.outer_primitive.bind(inputs), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
"""
gated_gelu infer_sharding_from_operands
"""
del result_infos # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
return out_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
"""
gated_gelu partitioning
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
impl = GeluPrimitive.impl
return mesh, impl, out_sharding, arg_shardings
register_primitive(GeluPrimitive)
def gelu(inputs: jnp.ndarray) -> jnp.ndarray:
"""
gelu wrapper
Return geglu(inputs)
Assume inputs has two dimensions shape and the memory layout is (N..., H)
"""
return GeluPrimitive.outer_primitive.bind(inputs)
class DGeluPrimitive(BasePrimitive):
"""
Dgated Gelu Primitive
"""
name = "te_dgelu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
@staticmethod
def abstract(dz_aval, x_aval):
"""
dgelu abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert dz_aval.shape == x_aval.shape
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, dz, x):
"""
dgelu lowering rules
"""
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
ir_in_type = ir.RankedTensorType(dz.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type)
gi_shape = gi_type.shape
assert ir_in_shape == gi_shape
ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
i_hidden_size = ir_in_shape[-1]
out_dtype = ir_in_type.element_type
out_shape = gi_shape
out_types = [
ir.RankedTensorType.get(out_shape, out_dtype),
]
operands = [dz, x]
operand_shapes = [ir_in_shape, gi_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
in_dtype, in_dtype)
out = custom_caller(DGeluPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(dz, x):
"""
dgelu implementation
"""
assert DGeluPrimitive.inner_primitive is not None
dx = DGeluPrimitive.inner_primitive.bind(dz, x)
return dx
@staticmethod
def batcher(batched_args, batch_dims):
"""
dgelu batcher
"""
_check_valid_batch_dims(batch_dims)
assert DGeluPrimitive.outer_primitive is not None
dz, x = batched_args
_, x_bdim = batch_dims
out_bdims = x_bdim
return DGeluPrimitive.outer_primitive.bind(dz, x), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
"""
dgelu infer_sharding_from_operands
"""
del result_infos # Unused.
gelu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
return dx_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
"""
dgelu partition
"""
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
impl = DGeluPrimitive.impl
return mesh, impl, out_shardings, arg_shardings
register_primitive(DGeluPrimitive)
def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
"""
dgelu fusion wrapper
Return dgeglu(inputs)
"""
return DGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)
class GatedGeluPrimitive(BasePrimitive):
"""
Gated Gelu Froward Primitive
"""
name = "te_gated_gelu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
@staticmethod
def abstract(x_aval):
"""
gated_gelu abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
x_shape = x_aval.shape x_shape = x_aval.shape
assert x_shape[-2] == 2 # Assume x in (....., 2, hidden) assert (x_shape[-2] == 2 or x_shape[-2] == 1)
hidden_size = x_shape[-1] hidden_size = x_shape[-1]
batch_shapes = x_shape[:-2] batch_shapes = x_shape[:-2]
x_shape = x_aval.shape
out_aval = core.raise_to_shaped(x_aval) out_aval = core.raise_to_shaped(x_aval)
out_shape = (batch_shapes) + (hidden_size,) out_shape = (batch_shapes) + (hidden_size,)
out_aval = out_aval.update(shape=out_shape, dtype=dtype) out_aval = out_aval.update(shape=out_shape, dtype=dtype)
...@@ -2801,9 +2594,9 @@ class GatedGeluPrimitive(BasePrimitive): ...@@ -2801,9 +2594,9 @@ class GatedGeluPrimitive(BasePrimitive):
return out_aval return out_aval
@staticmethod @staticmethod
def lowering(ctx, x): def lowering(ctx, x, *, act_enum):
""" """
gated_gelu lowering rules act_lu lowering rules
""" """
(x_aval,) = ctx.avals_in (x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -2821,100 +2614,101 @@ class GatedGeluPrimitive(BasePrimitive): ...@@ -2821,100 +2614,101 @@ class GatedGeluPrimitive(BasePrimitive):
hidden_size = ir_x_shape[-1] hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-2]) batch_size = reduce(operator.mul, ir_x_shape[:-2])
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, opaque = transformer_engine_jax.pack_common_descriptor(
in_dtype) (batch_size, hidden_size), in_dtype, in_dtype, act_enum)
out = custom_caller(GatedGeluPrimitive.name, args, opaque, False) out = custom_caller(ActLuPrimitive.name, args, opaque, False)
return [out] return [out]
@staticmethod @staticmethod
def impl(x): def impl(x, act_enum):
assert GatedGeluPrimitive.inner_primitive is not None assert ActLuPrimitive.inner_primitive is not None
out = GatedGeluPrimitive.inner_primitive.bind(x) out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
return out return out
@staticmethod @staticmethod
def batcher(batched_args, batch_dims): def batcher(batched_args, batch_dims, *, act_enum):
""" """
gated_gelu batcher act_lu batcher
""" """
_check_valid_batch_dims(batch_dims) _check_valid_batch_dims(batch_dims)
assert GatedGeluPrimitive.outer_primitive is not None assert ActLuPrimitive.outer_primitive is not None
inputs, = batched_args inputs, = batched_args
inputs_bdim, = batch_dims inputs_bdim, = batch_dims
out_bdims = inputs_bdim out_bdims = inputs_bdim
return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos): def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
""" """
gated_gelu infer_sharding_from_operands act_lu infer_sharding_from_operands
""" """
del result_infos # Unused. del result_infos, act_enum # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
return out_sharding return out_sharding
@staticmethod @staticmethod
def partition(mesh, arg_infos, result_infos): def partition(act_enum, mesh, arg_infos, result_infos):
""" """
gated_gelu partitioning act_lu partitioning
""" """
del result_infos del result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
impl = GatedGeluPrimitive.impl impl = ActLuPrimitive.impl
return mesh, impl, out_sharding, arg_shardings return mesh, impl, out_sharding, arg_shardings
register_primitive(GatedGeluPrimitive) register_primitive(ActLuPrimitive)
def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray: def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
""" """
gated gelu wrapper act_lu wrapper
Return FP8(geglu(inputs)) Return act_lu(inputs)
Assume inputs has two dimensions shape and the memory layout is (N, 2, H) Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
""" """
return GatedGeluPrimitive.outer_primitive.bind(inputs) act_type_id = ActivationEnum[activation_type]
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
class DgatedGeluPrimitive(BasePrimitive): class DActLuPrimitive(BasePrimitive):
""" """
Dgated Gelu Primitive Dgated ActLu Primitive
""" """
name = "te_dgated_gelu" name = "te_dact_lu"
multiple_results = False multiple_results = False
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
impl_static_args = () impl_static_args = (2,)
@staticmethod @staticmethod
def abstract(dz_aval, x_aval): def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument
""" """
dgated_gelu abstract dact_lu abstract
""" """
dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype assert x_aval.dtype == dtype
for axis in range(len(dz_aval.shape) - 1): for axis in range(len(dz_aval.shape) - 1):
assert dz_aval.shape[axis] == x_aval.shape[axis] assert dz_aval.shape[axis] == x_aval.shape[axis]
assert (x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1)
assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden)
i_hidden_size = dz_aval.shape[-1] i_hidden_size = dz_aval.shape[-1]
g_hidden_size = x_aval.shape[-1] g_hidden_size = x_aval.shape[-1]
assert i_hidden_size == g_hidden_size assert i_hidden_size == g_hidden_size
out_aval = core.raise_to_shaped(x_aval) out_aval = core.raise_to_shaped(x_aval)
return out_aval return out_aval
@staticmethod @staticmethod
def lowering(ctx, dz, x): def lowering(ctx, dz, x, *, act_enum):
""" """
dgated_gelu lowering rules dact_lu lowering rules
""" """
in_aval, gi_aval = ctx.avals_in in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -2942,66 +2736,68 @@ class DgatedGeluPrimitive(BasePrimitive): ...@@ -2942,66 +2736,68 @@ class DgatedGeluPrimitive(BasePrimitive):
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
in_dtype, in_dtype) in_dtype, in_dtype, act_enum)
out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False) out = custom_caller(DActLuPrimitive.name, args, opaque, False)
return [out] return [out]
@staticmethod @staticmethod
def impl(dz, x): def impl(dz, x, act_enum):
""" """
dgated_gelu implementation dact_lu implementation
""" """
assert DgatedGeluPrimitive.inner_primitive is not None assert DActLuPrimitive.inner_primitive is not None
dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x) dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
return dx return dx
@staticmethod @staticmethod
def batcher(batched_args, batch_dims): def batcher(batched_args, batch_dims, *, act_enum):
""" """
dgated_gelu batcher dact_lu batcher
""" """
_check_valid_batch_dims(batch_dims) _check_valid_batch_dims(batch_dims)
assert DgatedGeluPrimitive.outer_primitive is not None assert DActLuPrimitive.outer_primitive is not None
dz, x = batched_args dz, x = batched_args
_, x_bdim = batch_dims _, x_bdim = batch_dims
out_bdims = x_bdim out_bdims = x_bdim
return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos): def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
""" """
dgated_gelu infer_sharding_from_operands dact_lu infer_sharding_from_operands
""" """
del result_infos # Unused. del result_infos, act_enum # Unused.
gelu_out_spec = get_padded_spec(arg_infos[1]) act_lu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec)) dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
return dx_sharding return dx_sharding
@staticmethod @staticmethod
def partition(mesh, arg_infos, result_infos): def partition(act_enum, mesh, arg_infos, result_infos):
""" """
dgated_gelu partition dact_lu partition
""" """
del result_infos del result_infos, act_enum
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding out_shardings = dx_sharding
impl = DgatedGeluPrimitive.impl impl = DActLuPrimitive.impl
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
register_primitive(DgatedGeluPrimitive) register_primitive(DActLuPrimitive)
def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: def dact_lu(inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
""" """
dgated_gelu fusion wrapper dact_lu fusion wrapper
Return dgeglu(inputs) Return dgated_act_lu(inputs)
""" """
return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) act_type_id = ActivationEnum[activation_type]
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
def _normalize_axis_boundary(axis, ndim): def _normalize_axis_boundary(axis, ndim):
...@@ -3958,20 +3754,21 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale ...@@ -3958,20 +3754,21 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale
epsilon=epsilon) epsilon=epsilon)
class GeluFp8Primitive(BasePrimitive): class ActLuFp8Primitive(BasePrimitive):
""" """
Gelu FP8 Primitive ActLu FP8 Primitive
""" """
name = "te_gelu_fp8" name = "te_act_lu_fp8"
multiple_results = True multiple_results = True
impl_static_args = (4,) #out_dtype impl_static_args = (4, 5) #out_dtype, act_enum
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
act_enum): # pylint: disable=unused-argument
""" """
te_gelu_p abstract te_act_lu_p abstract
""" """
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
# Currently only support casting to E4M3 only in C side. # Currently only support casting to E4M3 only in C side.
...@@ -3981,15 +3778,19 @@ class GeluFp8Primitive(BasePrimitive): ...@@ -3981,15 +3778,19 @@ class GeluFp8Primitive(BasePrimitive):
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) assert (x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2)
hidden_size = x_aval.shape[-1]
batch_shape = x_aval.shape[:-2]
out_shape = (batch_shape) + (hidden_size,)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out_aval, updated_amax_aval return out_aval, updated_amax_aval
@staticmethod @staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
""" """
te_gated_gelu_p lowering rules te_gated_act_lu_p lowering rules
""" """
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -4006,8 +3807,9 @@ class GeluFp8Primitive(BasePrimitive): ...@@ -4006,8 +3807,9 @@ class GeluFp8Primitive(BasePrimitive):
ir_scale_inv_shape = ir_amax_shape ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1] hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-1]) batch_shape = ir_x_shape[:-2]
out_shape = ir_x_shape batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
...@@ -4016,11 +3818,13 @@ class GeluFp8Primitive(BasePrimitive): ...@@ -4016,11 +3818,13 @@ class GeluFp8Primitive(BasePrimitive):
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), opaque = transformer_engine_jax.pack_common_descriptor((
jax_dtype_to_te_dtype(x_aval.dtype), batch_size, hidden_size),
jax_dtype_to_te_dtype(out_dtype)) jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum)
out = custom_caller(GeluFp8Primitive.name, out = custom_caller(ActLuFp8Primitive.name,
args, args,
opaque, opaque,
False, False,
...@@ -4029,55 +3833,58 @@ class GeluFp8Primitive(BasePrimitive): ...@@ -4029,55 +3833,58 @@ class GeluFp8Primitive(BasePrimitive):
return out return out
@staticmethod @staticmethod
def impl(x, amax, scale, scale_inv, out_dtype): def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
""" """
to describe implementation to describe implementation
""" """
assert GeluFp8Primitive.inner_primitive is not None assert ActLuFp8Primitive.inner_primitive is not None
out, updated_amax = GeluFp8Primitive.inner_primitive.bind(x, out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(x,
amax, amax,
scale, scale,
scale_inv, scale_inv,
out_dtype=out_dtype) out_dtype=out_dtype,
act_enum=act_enum)
return out, updated_amax return out, updated_amax
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, out_dtype): def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
_check_valid_batch_dims(batch_dims) _check_valid_batch_dims(batch_dims)
assert GeluFp8Primitive.outer_primitive is not None assert ActLuFp8Primitive.outer_primitive is not None
x, amax, scale, scale_inv = batched_args x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, _, _ = batch_dims x_bdim, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, amax_bdim out_bdims = x_bdim, amax_bdim
return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype), out_bdims out_dtype=out_dtype,
act_enum=act_enum), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
del out_dtype, result_infos del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding) return (out_sharding, amax_sharding)
@staticmethod @staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos): def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding) out_shardings = (out_sharding, amax_sharding)
def sharded_impl(x, amax, scale, scale_inv): def sharded_impl(x, amax, scale, scale_inv):
local_x, local_amax = GeluFp8Primitive.impl(x, local_x, local_amax = ActLuFp8Primitive.impl(x,
amax, amax,
scale, scale,
scale_inv, scale_inv,
out_dtype=out_dtype) out_dtype=out_dtype,
act_enum=act_enum)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_x, global_updated_amax return local_x, global_updated_amax
...@@ -4085,34 +3892,40 @@ class GeluFp8Primitive(BasePrimitive): ...@@ -4085,34 +3892,40 @@ class GeluFp8Primitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(GeluFp8Primitive) register_primitive(ActLuFp8Primitive)
def gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, def act_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: out_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
""" """
gated gelu wrapper act wrapper
Return FP8(geglu(x)) Return FP8(act_lu(x))
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
""" """
return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) act_type_id = ActivationEnum[activation_type]
return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype,
act_enum = act_type_id)
class DGeluDBiasCastTransposePrimitive(BasePrimitive): class DActLuDBiasCastTransposePrimitive(BasePrimitive):
""" """
DGelu DBias Cast Transpose Primitive DActLu DBias Cast Transpose Primitive
""" """
name = "te_dgelu_dbias_cast_transpose" name = "te_dact_lu_dbias_cast_transpose"
multiple_results = True multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum
impl_static_args = (5, 6, 7) impl_static_args = (5, 6, 7, 8)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary, transpose_axis_boundary): static_axis_boundary, transpose_axis_boundary,
act_enum): # pylint: disable=unused-argument
""" """
te_dgelu_dbais_cast_transpose_p abstract te_dact_lu_dbais_cast_transpose_p abstract
""" """
dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -4123,7 +3936,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4123,7 +3936,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
ir_hidden_szie = dz_aval.shape[-1] ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1] gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size assert ir_hidden_szie == gi_hidden_size
t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) t_shape = _multidim_transpose(x_aval.shape,
static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
...@@ -4146,18 +3960,18 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4146,18 +3960,18 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
@staticmethod @staticmethod
def outer_abstract(*args, **kwargs): def outer_abstract(*args, **kwargs):
""" """
te_dgelu_dbais_cast_transpose_p outer abstract te_dact_lu_dbais_cast_transpose_p outer abstract
""" """
out, t_out, dbias, updated_amax_aval, _ = \ out, t_out, dbias, updated_amax_aval, _ = \
DGeluDBiasCastTransposePrimitive.abstract(*args, **kwargs) DActLuDBiasCastTransposePrimitive.abstract(*args, **kwargs)
return out, t_out, dbias, updated_amax_aval return out, t_out, dbias, updated_amax_aval
@staticmethod @staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
transpose_axis_boundary): transpose_axis_boundary, act_enum):
""" """
te_dgated_gelu_cast_transpose_p lowering rules te_dgated_act_lu_cast_transpose_p lowering rules
""" """
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -4169,11 +3983,11 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4169,11 +3983,11 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
ir_dz_shape = ir_dz_type.shape ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape x_shape = x_type.shape
assert ir_dz_shape == x_shape dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) assert dz_batch_szie == x_batch_size
ir_hidden_szie = ir_dz_shape[-1] ir_hidden_szie = ir_dz_shape[-1]
contracted_x_shape = (batch_szie, ir_hidden_szie) contracted_x_shape = (x_batch_size, ir_hidden_szie)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type) ir_amax_type = ir.RankedTensorType(amax.type)
...@@ -4199,9 +4013,10 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4199,9 +4013,10 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor( opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
act_enum)
out = custom_caller(DGeluDBiasCastTransposePrimitive.name, out = custom_caller(DActLuDBiasCastTransposePrimitive.name,
args, args,
opaque, opaque,
False, False,
...@@ -4211,12 +4026,12 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4211,12 +4026,12 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
@staticmethod @staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
transpose_axis_boundary): transpose_axis_boundary, act_enum):
""" """
to describe implementation to describe implementation
""" """
assert DGeluDBiasCastTransposePrimitive.inner_primitive is not None assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DGeluDBiasCastTransposePrimitive.inner_primitive.bind( out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
dz, dz,
x, x,
amax, amax,
...@@ -4224,18 +4039,19 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4224,18 +4039,19 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary) transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum)
return out, t_out, dbias, updated_amax return out, t_out, dbias, updated_amax
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
transpose_axis_boundary): transpose_axis_boundary, act_enum):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
del static_axis_boundary del static_axis_boundary
_check_valid_batch_dims(batch_dims) _check_valid_batch_dims(batch_dims)
assert DGeluDBiasCastTransposePrimitive.outer_primitive is not None assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims x_bdim, _, amax_bdim, _, _ = batch_dims
...@@ -4244,7 +4060,7 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4244,7 +4060,7 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
transpose_axis_boundary += 1 # Plus batch dim transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return DGeluDBiasCastTransposePrimitive.outer_primitive.bind( return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz, dz,
x, x,
amax, amax,
...@@ -4252,12 +4068,13 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4252,12 +4068,13 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=x_bdim, static_axis_boundary=x_bdim,
transpose_axis_boundary=transpose_axis_boundary), out_bdims transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary,
arg_infos, result_infos): act_enum, mesh, arg_infos, result_infos):
del out_dtype, result_infos del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
...@@ -4268,8 +4085,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4268,8 +4085,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
@staticmethod @staticmethod
def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, def partition(out_dtype, static_axis_boundary, transpose_axis_boundary,
result_infos): act_enum, mesh, arg_infos, result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
...@@ -4285,7 +4102,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4285,7 +4102,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
amax_sharding) amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv): def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = DGeluDBiasCastTransposePrimitive.impl( local_out, local_t_out, local_dbias, local_amax =\
DActLuDBiasCastTransposePrimitive.impl(
dz, dz,
x, x,
amax, amax,
...@@ -4293,7 +4111,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4293,7 +4111,8 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary) transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_enum)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_dbias, global_updated_amax return local_out, local_t_out, global_dbias, global_updated_amax
...@@ -4301,26 +4120,30 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive): ...@@ -4301,26 +4120,30 @@ class DGeluDBiasCastTransposePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DGeluDBiasCastTransposePrimitive) register_primitive(DActLuDBiasCastTransposePrimitive)
def dgelu_dbias_cast_transpose( def dact_lu_dbias_cast_transpose(
dz: jnp.ndarray, dz: jnp.ndarray,
x: jnp.ndarray, x: jnp.ndarray,
amax: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: TEDType, out_dtype: TEDType,
static_axis_boundary: int, static_axis_boundary: int,
transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: transpose_axis_boundary: int = -1,
activation_type: Sequence[Union[str, Callable]] = ('gelu',)
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
""" """
cast transpose dgelu and dbias fusion wrapper cast transpose dact_lu and dbias fusion wrapper
Return FP8(dgeglu(inputs)), dbias Return FP8(dact_lu(inputs)), dbias
ONLY support non-gated activation type
""" """
if static_axis_boundary < 0: if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
return DGeluDBiasCastTransposePrimitive.outer_primitive.bind( act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz, dz,
x, x,
amax, amax,
...@@ -4328,7 +4151,8 @@ def dgelu_dbias_cast_transpose( ...@@ -4328,7 +4151,8 @@ def dgelu_dbias_cast_transpose(
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary) transpose_axis_boundary=transpose_axis_boundary,
act_enum=act_type_id)
class DBiasCastTransposePrimitive(BasePrimitive): class DBiasCastTransposePrimitive(BasePrimitive):
...@@ -4353,13 +4177,11 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -4353,13 +4177,11 @@ class DBiasCastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
gi_hidden_size = dz_aval.shape[-1] gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary) t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
if dz_aval.shape[-2] == 2:
gi_hidden_size *= 2
dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size) dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
...@@ -4398,13 +4220,9 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -4398,13 +4220,9 @@ class DBiasCastTransposePrimitive(BasePrimitive):
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type) ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape ir_dz_shape = ir_dz_type.shape
ir_hidden_szie = ir_dz_shape[-1] batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
if dz_aval.shape[-2] == 2: ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
batch_szie = reduce(operator.mul, ir_dz_shape[:-2]) contracted_dz_shape = (batch_size, ir_hidden_size)
ir_hidden_szie *= 2
else:
batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
contracted_dz_shape = (batch_szie, ir_hidden_szie)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type) ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type ir_amax_dtype = ir_amax_type.element_type
...@@ -4413,7 +4231,7 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -4413,7 +4231,7 @@ class DBiasCastTransposePrimitive(BasePrimitive):
ir_scale_inv_shape = ir_amax_shape ir_scale_inv_shape = ir_amax_shape
transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary, transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary,
transpose_axis_boundary) transpose_axis_boundary)
dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_szie) dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size)
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
...@@ -4556,1356 +4374,60 @@ def dbias_cast_transpose( ...@@ -4556,1356 +4374,60 @@ def dbias_cast_transpose(
transpose_axis_boundary=transpose_axis_boundary) transpose_axis_boundary=transpose_axis_boundary)
class GatedGeluFp8Primitive(BasePrimitive): class DgatedActLuCastTransposePrimitive(BasePrimitive):
""" """
Gated Gelu FP8 Primitive Dgated ActLu Cast Transpose Primitive
""" """
name = "te_gated_gelu_fp8" name = "te_dgated_act_lu_cast_transpose"
multiple_results = True multiple_results = True
impl_static_args = (4,) #out_dtype impl_static_args = (5, 6, 7) # out_dtype, static_axis_boundary, act_enum
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary, act_enum): # pylint: disable=unused-argument
""" """
te_gated_gelu_p abstract te_dgated_act_lu_cast_transpose_p abstract
""" """
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
# Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert x_aval.shape[-2] == 2 # Linear + GeLU
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
ir_hidden_szie = dz_aval.shape[-1]
assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) gi_hidden_size = x_aval.shape[-1]
hidden_size = x_aval.shape[-1] assert ir_hidden_szie == gi_hidden_size
batch_shape = x_aval.shape[:-2] t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out_shape = (batch_shape) + (hidden_size,) out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out, t_out, updated_amax_aval
return out_aval, updated_amax_aval
@staticmethod @staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
""" """
te_gated_gelu_p lowering rules te_dgated_act_lu_cast_transpose_p lowering rules
""" """
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type) ir_dz_type = ir.RankedTensorType(dz.type)
ir_x_shape = ir_x_type.shape ir_dz_shape = ir_dz_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) x_type = ir.RankedTensorType(x.type)
ir_amax_type = ir.RankedTensorType(amax.type) x_shape = x_type.shape
ir_amax_dtype = ir_amax_type.element_type dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
ir_amax_shape = ir_amax_type.shape x_batch_size = reduce(operator.mul, x_shape[:-2])
ir_scale_shape = ir_amax_shape assert dz_batch_szie == x_batch_size
ir_scale_inv_shape = ir_amax_shape assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
hidden_size = ir_x_shape[-1] gi_hidden_size = x_shape[-1]
batch_shape = ir_x_shape[:-2] assert ir_hidden_szie == gi_hidden_size
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))
out = custom_caller(GatedGeluFp8Primitive.name,
args,
opaque,
False,
operand_output_aliases={1: 1})
return out
@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype):
"""
to describe implementation
"""
assert GatedGeluFp8Primitive.inner_primitive is not None
out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
return out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype):
"""
to describe batch rules for vmap
"""
_check_valid_batch_dims(batch_dims)
assert GatedGeluFp8Primitive.outer_primitive is not None
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, amax_bdim
return GatedGeluFp8Primitive.outer_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)
def sharded_impl(x, amax, scale, scale_inv):
local_x, local_amax = GatedGeluFp8Primitive.impl(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_x, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(GatedGeluFp8Primitive)
def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
gated gelu wrapper
Return FP8(geglu(x))
"""
return GatedGeluFp8Primitive.outer_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
class DgatedGeluCastTransposePrimitive(BasePrimitive):
"""
Dgated Gelu Cast Transpose Primitive
"""
name = "te_dgated_gelu_cast_transpose"
multiple_results = True
impl_static_args = (5, 6) # out_dtype, static_axis_boundary
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary):
"""
te_dgated_gelu_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert x_aval.shape[-2] == 2 # Linear + GeLU
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size
t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out, t_out, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
"""
te_dgated_gelu_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_szie == gi_hidden_size
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))
out = custom_caller(DgatedGeluCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 2})
return out
@staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
"""
to describe implementation
"""
assert DgatedGeluCastTransposePrimitive.inner_primitive is not None
out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary)
return out, t_out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
"""
to describe batch rules for vmap
"""
del static_axis_boundary
_check_valid_batch_dims(batch_dims)
assert DgatedGeluCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim
return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
static_axis_boundary=x_bdim), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, tranposed_out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DgatedGeluCastTransposePrimitive)
def dgated_gelu_cast_transpose(
dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, out_dtype: TEDType,
static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose d_gated_gelu fusion wrapper
Return FP8(dgeglu(inputs))
"""
return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary)
# Primitives for SwiGLU and SiLU
class SiluPrimitive(BasePrimitive):
    """
    SiLU Forward Primitive (non-gated): element-wise SiLU on the input.
    """
    name = "te_silu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()
    @staticmethod
    def abstract(x_aval):
        """
        silu abstract: output aval mirrors the input's shape and dtype.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval
    @staticmethod
    def lowering(ctx, x):
        """
        silu lowering rules: emit the TE custom call with a (batch, hidden)
        descriptor, folding all leading axes into the batch dimension.
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-1])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)
        out = custom_caller(SiluPrimitive.name, args, opaque, False)
        return [out]
    @staticmethod
    def impl(x):
        # Delegate to the inner primitive installed by register_primitive.
        assert SiluPrimitive.inner_primitive is not None
        out = SiluPrimitive.inner_primitive.bind(x)
        return out
    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        silu batcher: element-wise op, the batch dim passes straight through.
        """
        _check_valid_batch_dims(batch_dims)
        assert SiluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims
        out_bdims = inputs_bdim
        return SiluPrimitive.outer_primitive.bind(inputs), out_bdims
    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        silu infer_sharding_from_operands: output sharded like the input.
        """
        del result_infos  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        return out_sharding
    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        silu partitioning: element-wise, so the plain impl runs per shard
        with no collectives.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        impl = SiluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings
register_primitive(SiluPrimitive)
def silu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    SiLU activation wrapper.

    Return silu(inputs), same shape and dtype as the input.
    Assume the memory layout is (N..., H), i.e. the last axis is hidden.
    """
    return SiluPrimitive.outer_primitive.bind(inputs)
class DSiluPrimitive(BasePrimitive):
    """
    DSilu Primitive: backward (gradient) of the non-gated SiLU.
    """
    name = "te_dsilu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()
    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dsilu abstract: dz and x must match in shape/dtype; dx mirrors x.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert dz_aval.shape == x_aval.shape
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval
    @staticmethod
    def lowering(ctx, dz, x):
        """
        dsilu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        assert ir_in_shape == gi_shape
        # Fold all leading axes into one batch dimension for the TE kernel.
        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape
        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)
        out = custom_caller(DSiluPrimitive.name, args, opaque, False)
        return [out]
    @staticmethod
    def impl(dz, x):
        """
        dsilu implementation
        """
        assert DSiluPrimitive.inner_primitive is not None
        dx = DSiluPrimitive.inner_primitive.bind(dz, x)
        return dx
    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dsilu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert DSiluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims
        out_bdims = x_bdim
        return DSiluPrimitive.outer_primitive.bind(dz, x), out_bdims
    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dsilu infer_sharding_from_operands: dx sharded like the forward input.
        """
        del result_infos  # Unused.
        silu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec))
        return dx_sharding
    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dsilu partition: element-wise, no collectives required.
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DSiluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings
register_primitive(DSiluPrimitive)
def dsilu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    Backward (gradient) of the non-gated SiLU.

    Args:
        inputs: incoming gradient dz w.r.t. the SiLU output.
        silu_inputs: the original forward input of silu (same shape as dz).

    Return dsilu(inputs), same shape/dtype as ``silu_inputs``.
    """
    return DSiluPrimitive.outer_primitive.bind(inputs, silu_inputs)
class GatedSiluPrimitive(BasePrimitive):
    """
    Gated SiLU (SwiGLU) Forward Primitive.
    """
    name = "te_gated_silu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()
    @staticmethod
    def abstract(x_aval):
        """
        gated_silu abstract: input is (..., 2, hidden); the gating axis is
        consumed, so the output is (..., hidden).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        x_shape = x_aval.shape
        assert x_shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        # NOTE(review): the next line is a redundant re-assignment of x_shape.
        x_shape = x_aval.shape
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = (batch_shapes) + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)
        return out_aval
    @staticmethod
    def lowering(ctx, x):
        """
        gated_silu lowering rules: the gating axis (-2) is dropped from the
        output shape; leading axes fold into the batch dimension.
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-2])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)
        out = custom_caller(GatedSiluPrimitive.name, args, opaque, False)
        return [out]
    @staticmethod
    def impl(x):
        # Delegate to the inner primitive installed by register_primitive.
        assert GatedSiluPrimitive.inner_primitive is not None
        out = GatedSiluPrimitive.inner_primitive.bind(x)
        return out
    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        gated_silu batcher: the batch dim passes straight through.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedSiluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims
        out_bdims = inputs_bdim
        return GatedSiluPrimitive.outer_primitive.bind(inputs), out_bdims
    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        gated_silu infer_sharding_from_operands: output keeps the input spec
        minus the gating axis.
        """
        del result_infos  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding
    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        gated_silu partitioning: per-shard impl, no collectives.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        impl = GatedSiluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings
register_primitive(GatedSiluPrimitive)
def gated_silu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    Gated SiLU (SwiGLU) wrapper.

    Return gated_silu(inputs); the gating axis is consumed.
    Assume the memory layout of inputs is (N..., 2, H), output is (N..., H).
    """
    return GatedSiluPrimitive.outer_primitive.bind(inputs)
class DgatedSiluPrimitive(BasePrimitive):
    """
    Dgated SiLU Primitive: backward (gradient) of the gated SiLU (SwiGLU).
    """
    name = "te_dgated_silu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()
    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dgated_silu abstract: dz is (..., hidden), x is (..., 2, hidden);
        the gradient dx mirrors x.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # dz and x agree on all leading (batch) axes.
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]
        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval
    @staticmethod
    def lowering(ctx, dz, x):
        """
        dgated_silu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        # dz and x agree on all leading (batch) axes.
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]
        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape
        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)
        out = custom_caller(DgatedSiluPrimitive.name, args, opaque, False)
        return [out]
    @staticmethod
    def impl(dz, x):
        """
        dgated_silu implementation
        """
        assert DgatedSiluPrimitive.inner_primitive is not None
        dx = DgatedSiluPrimitive.inner_primitive.bind(dz, x)
        return dx
    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dgated_silu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert DgatedSiluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims
        out_bdims = x_bdim
        return DgatedSiluPrimitive.outer_primitive.bind(dz, x), out_bdims
    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dgated_silu infer_sharding_from_operands: dx sharded like the
        forward input x.
        """
        del result_infos  # Unused.
        silu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec))
        return dx_sharding
    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dgated_silu partition: element-wise, no collectives required.
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DgatedSiluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings
register_primitive(DgatedSiluPrimitive)
def dgated_silu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    Backward (gradient) of the gated SiLU (SwiGLU).

    Args:
        inputs: incoming gradient dz, shaped (N..., H).
        silu_inputs: the original forward input, shaped (N..., 2, H).

    Return dgated_silu(inputs), same shape/dtype as ``silu_inputs``.
    """
    return DgatedSiluPrimitive.outer_primitive.bind(inputs, silu_inputs)
class SiluFp8Primitive(BasePrimitive):
    """
    SiLU FP8 Primitive: fused non-gated SiLU forward + FP8 cast.
    """
    name = "te_silu_fp8"
    multiple_results = True
    impl_static_args = (4,)  # out_dtype
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_silu_p abstract: output keeps x's shape but with FP8 dtype;
        the amax aval is passed through (updated in place by the kernel).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out_aval, updated_amax_aval
    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_silu_fp8_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-1])
        out_shape = ir_x_shape
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))
        # amax (operand 1) aliases output 1 so the kernel updates it in place.
        out = custom_caller(SiluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})
        return out
    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        silu_fp8 implementation: forwards to the inner primitive.
        """
        assert SiluFp8Primitive.inner_primitive is not None
        out, updated_amax = SiluFp8Primitive.inner_primitive.bind(x,
                                                                  amax,
                                                                  scale,
                                                                  scale_inv,
                                                                  out_dtype=out_dtype)
        return out, updated_amax
    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        silu_fp8 batch rules for vmap.
        """
        _check_valid_batch_dims(batch_dims)
        assert SiluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, amax_bdim
        return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
                                                     out_dtype=out_dtype), out_bdims
    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        # Output follows x's sharding; amax keeps its own sharding.
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)
    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        # Per-shard impl; local amax must be max-reduced across the mesh.
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)
        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = SiluFp8Primitive.impl(x,
                                                        amax,
                                                        scale,
                                                        scale_inv,
                                                        out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_x, global_updated_amax
        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(SiluFp8Primitive)
def silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
             out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Fused SiLU + FP8 cast wrapper (non-gated).

    Args:
        x: input tensor, layout (N..., H).
        amax, scale, scale_inv: FP8 scaling metadata (float32).
        out_dtype: target FP8 dtype (currently only E4M3 is supported).

    Return (FP8(silu(x)), updated_amax).
    """
    # NOTE: return annotation fixed — the primitive yields exactly two arrays.
    return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
class DSiluDBiasCastTransposePrimitive(BasePrimitive):
    """
    DSilu DBias Cast Transpose Primitive: fused backward-SiLU, bias gradient,
    FP8 cast and multi-dim transpose.
    """
    name = "te_dsilu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (5, 6, 7)
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary):
        """
        te_dsilu_dbias_cast_transpose_p abstract: returns (out, t_out, dbias,
        updated_amax, workspace); the workspace aval is sized by the C side.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # dbias keeps the static leading axes plus the hidden axis.
        dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
            x_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        return out, t_out, dbias, updated_amax_aval, wkspace_aval
    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dsilu_dbias_cast_transpose_p outer abstract: hides the workspace
        output from the user-facing primitive.
        """
        out, t_out, dbias, updated_amax_aval, _ = \
            DSiluDBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval
    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_dsilu_dbias_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        assert ir_dz_shape == x_shape
        batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
        ir_hidden_szie = ir_dz_shape[-1]
        contracted_x_shape = (batch_szie, ir_hidden_szie)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_szie)
        wkspace_aval = ctx.avals_out[-1]
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
        # amax (operand 2) aliases output 3 so the kernel updates it in place.
        out = custom_caller(DSiluDBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 3})
        return out
    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary):
        """
        dsilu_dbias_cast_transpose implementation: drop the workspace output.
        """
        assert DSiluDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DSiluDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)
        return out, t_out, dbias, updated_amax
    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        dsilu_dbias_cast_transpose batch rules for vmap: the batch axis is
        treated as a new static leading axis.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DSiluDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim
        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return DSiluDBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims
    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_shaprding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        # Per-shard impl; dbias is sum-reduced over DP/FSDP axes and amax is
        # max-reduced over all mesh axes except pipeline.
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_shaprding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding,
                         amax_sharding)
        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DSiluDBiasCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax
        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DSiluDBiasCastTransposePrimitive)
def dsilu_dbias_cast_transpose(
        dz: jnp.ndarray,
        x: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
        out_dtype: TEDType,
        static_axis_boundary: int,
        transpose_axis_boundary: int = -1
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused backward-SiLU + bias-gradient + FP8 cast + transpose wrapper.

    Args:
        dz: incoming gradient w.r.t. the SiLU output (same shape as x).
        x: forward input of silu.
        amax, scale, scale_inv: FP8 scaling metadata (float32).
        out_dtype: target FP8 dtype of both casted outputs.
        static_axis_boundary: leading axes kept static during the transpose;
            any negative value means no static axes.
        transpose_axis_boundary: axis at which the multi-dim transpose splits.

    Return (FP8(dsilu(dz, x)), FP8 transposed dsilu, dbias, updated_amax).
    """
    # NOTE: return annotation fixed — the primitive yields four arrays.
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    return DSiluDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
class GatedSiluFp8Primitive(BasePrimitive):
    """
    Gated SiLU FP8 Primitive: fused SwiGLU forward + FP8 cast.
    """
    name = "te_gated_silu_fp8"
    multiple_results = True
    impl_static_args = (4,)  # out_dtype
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_gated_silu_p abstract: x is (..., 2, hidden); the FP8 output drops
        the gating axis.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = (batch_shape) + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out_aval, updated_amax_aval
    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_gated_silu_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))
        # amax (operand 1) aliases output 1 so the kernel updates it in place.
        out = custom_caller(GatedSiluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})
        return out
    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        gated_silu_fp8 implementation: forwards to the inner primitive.
        """
        assert GatedSiluFp8Primitive.inner_primitive is not None
        out, updated_amax = GatedSiluFp8Primitive.inner_primitive.bind(x,
                                                                       amax,
                                                                       scale,
                                                                       scale_inv,
                                                                       out_dtype=out_dtype)
        return out, updated_amax
    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        gated_silu_fp8 batch rules for vmap.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedSiluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, amax_bdim
        return GatedSiluFp8Primitive.outer_primitive.bind(x,
                                                          amax,
                                                          scale,
                                                          scale_inv,
                                                          out_dtype=out_dtype), out_bdims
    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        # Output keeps x's spec minus the gating axis; amax keeps its own.
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)
    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        # Per-shard impl; local amax must be max-reduced across the mesh.
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)
        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GatedSiluFp8Primitive.impl(x,
                                                             amax,
                                                             scale,
                                                             scale_inv,
                                                             out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_x, global_updated_amax
        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(GatedSiluFp8Primitive)
def gated_silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Fused gated SiLU (SwiGLU) + FP8 cast wrapper.

    Args:
        x: input tensor, layout (N..., 2, H); the gating axis is consumed.
        amax, scale, scale_inv: FP8 scaling metadata (float32).
        out_dtype: target FP8 dtype (currently only E4M3 is supported).

    Return (FP8(gated_silu(x)), updated_amax).
    """
    # NOTE: return annotation fixed — the primitive yields exactly two arrays.
    return GatedSiluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)
class DgatedSiluCastTransposePrimitive(BasePrimitive):
"""
Dgated Silu Cast Transpose Primitive
"""
name = "te_dgated_silu_cast_transpose"
multiple_results = True
impl_static_args = (5, 6) # out_dtype, static_axis_boundary
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary):
"""
te_dgated_silu_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert x_aval.shape[-2] == 2 # Linear + GeLU
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size
t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out, t_out, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
"""
te_dgated_silu_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
assert x_shape[-2] == 2 # Linear + GeLU
ir_hidden_szie = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_szie == gi_hidden_size
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type) ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type ir_amax_dtype = ir_amax_type.element_type
...@@ -5922,11 +4444,13 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive): ...@@ -5922,11 +4444,13 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1]) contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, opaque = transformer_engine_jax.pack_common_descriptor(
jax_dtype_to_te_dtype(dz_aval.dtype), contracted_x_shape,
jax_dtype_to_te_dtype(out_dtype)) jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum)
out = custom_caller(DgatedSiluCastTransposePrimitive.name, out = custom_caller(DgatedActLuCastTransposePrimitive.name,
args, args,
opaque, opaque,
False, False,
...@@ -5935,41 +4459,43 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive): ...@@ -5935,41 +4459,43 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
return out return out
@staticmethod @staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary): def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
""" """
to describe implementation to describe implementation
""" """
assert DgatedSiluCastTransposePrimitive.inner_primitive is not None assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
out, t_out, updated_amax = DgatedSiluCastTransposePrimitive.inner_primitive.bind( out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
dz, dz,
x, x,
amax, amax,
scale, scale,
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary) static_axis_boundary=static_axis_boundary,
act_enum=act_enum)
return out, t_out, updated_amax return out, t_out, updated_amax
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary): def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
del static_axis_boundary del static_axis_boundary
_check_valid_batch_dims(batch_dims) _check_valid_batch_dims(batch_dims)
assert DgatedSiluCastTransposePrimitive.outer_primitive is not None assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim out_bdims = x_bdim, x_bdim, amax_bdim
return DgatedSiluCastTransposePrimitive.outer_primitive.bind( return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz, x, amax, scale, scale_inv, out_dtype=out_dtype, dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
static_axis_boundary=x_bdim), out_bdims static_axis_boundary=x_bdim,
act_enum=act_enum), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos, def infer_sharding_from_operands(out_dtype, static_axis_boundary, act_enum,
result_infos): mesh, arg_infos, result_infos):
del out_dtype, result_infos del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
...@@ -5978,7 +4504,8 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive): ...@@ -5978,7 +4504,8 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
return (out_sharding, tranposed_out_sharding, amax_sharding) return (out_sharding, tranposed_out_sharding, amax_sharding)
@staticmethod @staticmethod
def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos): def partition(out_dtype, static_axis_boundary, act_enum,
mesh, arg_infos, result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
...@@ -5990,36 +4517,41 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive): ...@@ -5990,36 +4517,41 @@ class DgatedSiluCastTransposePrimitive(BasePrimitive):
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv): def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_amax = DgatedSiluCastTransposePrimitive.impl( local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
dz, dz,
x, x,
amax, amax,
scale, scale,
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary) static_axis_boundary=static_axis_boundary,
act_enum=act_enum)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_updated_amax return local_out, local_t_out, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DgatedSiluCastTransposePrimitive) register_primitive(DgatedActLuCastTransposePrimitive)
def dgated_silu_cast_transpose( def dgated_act_lu_cast_transpose(
dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, out_dtype: TEDType, scale_inv: jnp.ndarray, out_dtype: TEDType,
static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: static_axis_boundary: int,
activation_type: Sequence[Union[str, Callable]]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
""" """
cast transpose d_gated_silu fusion wrapper cast transpose d_gated_act_lu fusion wrapper
Return FP8(dgeglu(inputs)) Return FP8(dgated_act_lu(inputs))
""" """
return DgatedSiluCastTransposePrimitive.outer_primitive.bind( act_type_id = ActivationEnum[activation_type]
return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz, dz,
x, x,
amax, amax,
scale, scale,
scale_inv, scale_inv,
out_dtype=out_dtype, out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary) static_axis_boundary=static_axis_boundary,
act_enum=act_type_id)
...@@ -25,25 +25,14 @@ pybind11::dict Registrations() { ...@@ -25,25 +25,14 @@ pybind11::dict Registrations() {
pybind11::dict dict; pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose); dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose); dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gelu"] = EncapsulateFunction(Gelu);
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8); dict["te_act_lu"] = EncapsulateFunction(ActLu);
dict["te_dgelu"] = EncapsulateFunction(DGelu); dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose); dict["te_dact_lu"] = EncapsulateFunction(DActLu);
dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose); dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
// TODO
dict["te_silu"] = EncapsulateFunction(Silu);
dict["te_silu_fp8"] = EncapsulateFunction(SiluFP8);
dict["te_dsilu"] = EncapsulateFunction(DSilu);
dict["te_dsilu_dbias_cast_transpose"] = EncapsulateFunction(DSiluDBiasCastTranspose);
dict["te_gated_silu"] = EncapsulateFunction(GatedSilu);
dict["te_gated_silu_fp8"] = EncapsulateFunction(GatedSiluFP8);
dict["te_dgated_silu"] = EncapsulateFunction(DGatedSilu);
dict["te_dgated_silu_cast_transpose"] = EncapsulateFunction(DGatedSiluCastTranspose);
//
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8); dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
...@@ -67,8 +56,11 @@ pybind11::dict Registrations() { ...@@ -67,8 +56,11 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) { PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations); m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor,
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor); pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor,
pybind11::arg(), pybind11::arg(), pybind11::arg(),
pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
...@@ -109,6 +101,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -109,6 +101,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
pybind11::enum_<NVTE_Activation_Enum>(m, "NVTE_Activation_Enum", pybind11::module_local())
.value("GELU", NVTE_Activation_Enum::GELU)
.value("GEGLU", NVTE_Activation_Enum::GEGLU)
.value("SILU", NVTE_Activation_Enum::SILU)
.value("SWIGLU", NVTE_Activation_Enum::SWIGLU);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
......
...@@ -37,6 +37,19 @@ std::vector<size_t> MakeShapeVector(NVTEShape shape) { ...@@ -37,6 +37,19 @@ std::vector<size_t> MakeShapeVector(NVTEShape shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim); return std::vector<size_t>(shape.data, shape.data + shape.ndim);
} }
size_t get_activation_len(NVTE_Activation_Enum act_enum) {
switch (act_enum) {
case NVTE_Activation_Enum::GELU: return 1;
case NVTE_Activation_Enum::GEGLU: return 2;
case NVTE_Activation_Enum::SILU: return 1;
case NVTE_Activation_Enum::SWIGLU: return 2;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
return -1;
}
}
template <typename T> template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) { pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T)); auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
...@@ -52,23 +65,26 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) { ...@@ -52,23 +65,26 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
} }
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype, pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype) { DType out_dtype, size_t act_enum) {
CustomCallCommonDescriptor desc; CustomCallCommonDescriptor desc;
desc.shape.from_vector(shape); desc.shape.from_vector(shape);
desc.in_dtype = in_dtype; desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype; desc.out_dtype = out_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc); return PackOpaque(desc);
} }
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape, pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype, const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype) { DType out_dtype, DType wk_dtype,
size_t act_enum) {
CustomCallCommonWkDescriptor desc; CustomCallCommonWkDescriptor desc;
desc.shape.from_vector(shape); desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape); desc.wkshape.from_vector(wkshape);
desc.in_dtype = in_dtype; desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype; desc.out_dtype = out_dtype;
desc.wk_dtype = wk_dtype; desc.wk_dtype = wk_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc); return PackOpaque(desc);
} }
...@@ -170,31 +186,50 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -170,31 +186,50 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
input_cast_trans_tensor.data(), stream); input_cast_trans_tensor.data(), stream);
} }
void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) { cudaStream_t stream, float *scale_inverse, float *amax, void *output,
auto input_shape = std::vector<size_t>{m, n}; NVTE_Activation_Enum act_enum) {
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape,
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape,
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax, static_cast<DType>(out_dtype), amax,
scale, scale_inverse); scale, scale_inverse);
switch (act_enum) {
nvte_gelu(input_tensor.data(), output_tensor.data(), stream); case NVTE_Activation_Enum::GELU:
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::GEGLU:
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SILU:
nvte_swish(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Enum::SWIGLU:
nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
} }
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
auto *output = buffers[1]; auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output); ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
nullptr, nullptr, output, act_enum);
} }
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]); float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]); float *scale = reinterpret_cast<float *>(buffers[2]);
...@@ -211,107 +246,91 @@ void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opa ...@@ -211,107 +246,91 @@ void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opa
} }
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
output); scale_inv, amax_out, output, act_enum);
} }
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
auto *gelu_input = buffers[1]; auto *act_input = buffers[1];
auto *output = buffers[2]; auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n}; auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n * act_len};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream); switch (act_enum) {
} case NVTE_Activation_Enum::GELU:
nvte_dgelu(input_tensor.data(), act_input_tensor.data(),
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, output_tensor.data(), stream);
size_t opaque_len) { break;
auto *input = buffers[0]; case NVTE_Activation_Enum::GEGLU:
auto *gelu_input = buffers[1]; nvte_dgeglu(input_tensor.data(), act_input_tensor.data(),
float *amax = reinterpret_cast<float *>(buffers[2]); output_tensor.data(), stream);
float *scale = reinterpret_cast<float *>(buffers[3]); break;
float *scale_inv = reinterpret_cast<float *>(buffers[4]); case NVTE_Activation_Enum::SILU:
auto *output = buffers[5]; nvte_dswish(input_tensor.data(), act_input_tensor.data(),
auto *output_trans = buffers[6]; output_tensor.data(), stream);
auto *dbias = buffers[7]; break;
float *amax_out = reinterpret_cast<float *>(buffers[8]); case NVTE_Activation_Enum::SWIGLU:
void *workspace_ptr = buffers[9]; nvte_dswiglu(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); break;
assert(amax == amax_out); default:
if (!use_fp8(desc.out_dtype)) { NVTE_ERROR("Unsupported ActivationEnum");
scale = nullptr; break;
scale_inv = nullptr;
amax_out = nullptr;
} }
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
} }
// HERE pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) { DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size}; auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace; TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), // For now, all dbias_dact(-s) have the same workspace size
output_trans_tensor.data(), dbias_tensor.data(), nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
dummy_workspace.data(), nullptr); output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape()); auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
} }
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]); auto *act_input = buffers[1];
float *scale = reinterpret_cast<float *>(buffers[2]); float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]); float *scale = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4]; float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output_trans = buffers[5]; auto *output = buffers[5];
auto *dbias = buffers[6]; auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]); auto *dbias = buffers[7];
void *workspace_ptr = buffers[8]; float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out); assert(amax == amax_out);
...@@ -322,12 +341,15 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -322,12 +341,15 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
} }
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n}; auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor = auto output_trans_tensor =
...@@ -336,81 +358,27 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -336,81 +358,27 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), switch (act_enum) {
output_trans_tensor.data(), dbias_tensor.data(), case NVTE_Activation_Enum::GELU:
workspace.data(), stream); nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
} output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, break;
cudaStream_t stream, float *scale_inverse, float *amax, void *output) { case NVTE_Activation_Enum::SILU:
auto input_shape = std::vector<size_t>{m, n * 2}; nvte_cast_transpose_dbias_dswish(input_tensor.data(), act_input_tensor.data(),
auto output_shape = std::vector<size_t>{m, n}; output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); break;
default:
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax, throw std::runtime_error("Activation Type is not Implemented in DActLuDBiasCastTranspose");
scale, scale_inverse); break;
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
}
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
output);
}
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
} }
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dgeglu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
} }
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
auto *gelu_input = buffers[1]; auto *act_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]); float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]); float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]); float *scale_inv = reinterpret_cast<float *>(buffers[4]);
...@@ -427,124 +395,69 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op ...@@ -427,124 +395,69 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
} }
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Enum>(desc.act_enum);;
auto input_shape = desc.shape.to_vector(); auto input_shape = desc.shape.to_vector();
auto gelu_input_shape = std::vector<size_t>{m, n * 2}; auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2}; auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m}; auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor = auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_dgeglu_cast_transpose(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), switch (act_enum) {
output_trans_tensor.data(), stream); case NVTE_Activation_Enum::GEGLU:
} nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
void SiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, stream);
cudaStream_t stream, float *scale_inverse, float *amax, void *output) { break;
auto input_shape = std::vector<size_t>{m, n}; case NVTE_Activation_Enum::SWIGLU:
auto output_shape = std::vector<size_t>{m, n}; nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); stream);
break;
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax, default:
scale, scale_inverse); NVTE_ERROR("Unsupported ActivationEnum");
break;
nvte_swish(input_tensor.data(), output_tensor.data(), stream);
}
void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output);
}
void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
} }
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dswish(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream);
} }
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) { DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size}; auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace; TensorWrapper dummy_workspace;
// For now, all dbias_dact(-s) have the same workspace size nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_trans_tensor.data(), dbias_tensor.data(),
output_tensor.data(), output_trans_tensor.data(), dummy_workspace.data(), nullptr);
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape()); auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
} }
void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
auto *silu_input = buffers[1]; float *amax = reinterpret_cast<float *>(buffers[1]);
float *amax = reinterpret_cast<float *>(buffers[2]); float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]); float *scale_inv = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]); auto *output = buffers[4];
auto *output = buffers[5]; auto *output_trans = buffers[5];
auto *output_trans = buffers[6]; auto *dbias = buffers[6];
auto *dbias = buffers[7]; float *amax_out = reinterpret_cast<float *>(buffers[7]);
float *amax_out = reinterpret_cast<float *>(buffers[8]); void *workspace_ptr = buffers[8];
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out); assert(amax == amax_out);
...@@ -556,13 +469,11 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op ...@@ -556,13 +469,11 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n}; auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor = auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor = auto output_trans_tensor =
...@@ -571,111 +482,9 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op ...@@ -571,111 +482,9 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dswish(input_tensor.data(), silu_input_tensor.data(), nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_tensor.data(), output_trans_tensor.data(), output_trans_tensor.data(), dbias_tensor.data(),
dbias_tensor.data(), workspace.data(), stream); workspace.data(), stream);
}
void GatedSiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
}
void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
output);
}
void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto silu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dswiglu(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream);
}
void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *silu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = desc.shape.to_vector();
auto silu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_dswiglu_cast_transpose(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), stream);
} }
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
...@@ -43,14 +43,24 @@ struct Shape { ...@@ -43,14 +43,24 @@ struct Shape {
} }
}; };
enum class NVTE_Activation_Enum {
GELU,
GEGLU,
SILU,
SWIGLU,
};
size_t get_activation_len(NVTE_Activation_Enum act_enum);
struct CustomCallCommonDescriptor { struct CustomCallCommonDescriptor {
Shape shape; Shape shape;
DType in_dtype; DType in_dtype;
DType out_dtype; DType out_dtype;
size_t act_enum;
}; };
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype, pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype); DType out_dtype, size_t act_enum = 0);
struct CustomCallCommonWkDescriptor { struct CustomCallCommonWkDescriptor {
Shape shape; Shape shape;
...@@ -58,11 +68,13 @@ struct CustomCallCommonWkDescriptor { ...@@ -58,11 +68,13 @@ struct CustomCallCommonWkDescriptor {
DType in_dtype; DType in_dtype;
DType out_dtype; DType out_dtype;
DType wk_dtype; DType wk_dtype;
size_t act_enum;
}; };
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape, pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype, const std::vector<size_t> &wkshape,
DType out_dtype, DType wk_dtype); DType in_dtype, DType out_dtype, DType wk_dtype,
size_t act_enum = 0);
struct CustomCallNormDescriptor { struct CustomCallNormDescriptor {
size_t batch_size; size_t batch_size;
...@@ -140,17 +152,16 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o ...@@ -140,17 +152,16 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// TODO (Phuong): Templating these 9x2 rountines before adding ReGLU, QuickGeLU, Squared ReLu void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype); DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
...@@ -159,31 +170,7 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi ...@@ -159,31 +170,7 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
...@@ -955,7 +955,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -955,7 +955,6 @@ class LayerNormMLP(TransformerEngineBase):
normalize_acts = tuple(reversed(normalize_acts) normalize_acts = tuple(reversed(normalize_acts)
if normalize_acts[0] == 'linear' else normalize_acts) if normalize_acts[0] == 'linear' else normalize_acts)
is_gated = normalize_acts in gated_act_pool
is_act_implemented = normalize_acts in (gated_act_pool + act_pool) is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\ use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
...@@ -1052,8 +1051,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1052,8 +1051,8 @@ class LayerNormMLP(TransformerEngineBase):
axes=self.bias_axes_2) axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype) bias_2 = bias_2.astype(self.dtype)
else: else:
bias_1 = jnp.empty(0, self.dtype) bias_1 = None
bias_2 = jnp.empty(0, self.dtype) bias_2 = None
out = fused_layernorm_fp8_mlp(y, out = fused_layernorm_fp8_mlp(y,
scale, scale,
...@@ -1134,7 +1133,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1134,7 +1133,6 @@ class LayerNormMLP(TransformerEngineBase):
x += jnp.reshape(bias_1, bias_1_shape) x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name) x = checkpoint_name(x, ffn1_ckpt_name)
activations = [] activations = []
if is_act_implemented: if is_act_implemented:
z = activation_lu(x, normalize_acts) z = activation_lu(x, normalize_acts)
...@@ -1144,8 +1142,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1144,8 +1142,8 @@ class LayerNormMLP(TransformerEngineBase):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = functools.reduce(operator.mul, activations) z = functools.reduce(operator.mul, activations)
if not is_gated: if num_activations == 1:
z = jnp.reshape(z, (*z.shape[:-2], -1)) z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate, z = nn.Dropout(rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_hidden_dropout_dims, broadcast_dims=self.intermediate_hidden_dropout_dims,
......
...@@ -11,14 +11,8 @@ import jax.numpy as jnp ...@@ -11,14 +11,8 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import gelu from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import silu, silu_fp8
from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose
from .cpp_extensions import gated_silu, gated_silu_fp8
from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
...@@ -26,44 +20,6 @@ from .layernorm import canonicalize_layernorm_type ...@@ -26,44 +20,6 @@ from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes from .sharding import with_sharding_constraint_by_logical_axes
activation_dict = {
('gelu',): {
'fwd': gelu,
"bwd": dgelu
},
('gelu', 'linear'): {
'fwd': gated_gelu,
'bwd': dgated_gelu
},
('silu',): {
'fwd': silu,
"bwd": dsilu
},
('silu', 'linear'): {
'fwd': gated_silu,
'bwd': dgated_silu
}
}
activation_fp8_dict = {
('gelu',): {
'fwd': gelu_fp8,
'bwd': dgelu_dbias_cast_transpose
},
('gelu', 'linear'): {
'fwd': gated_gelu_fp8,
'bwd': dgated_gelu_cast_transpose
},
('silu',): {
'fwd': silu_fp8,
'bwd': dsilu_dbias_cast_transpose
},
('silu', 'linear'): {
'fwd': gated_silu_fp8,
'bwd': dgated_silu_cast_transpose
}
}
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
""" """
...@@ -84,7 +40,7 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable ...@@ -84,7 +40,7 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
def _activation_lu_fwd_rule(x, activation_type): def _activation_lu_fwd_rule(x, activation_type):
fwd_output = activation_dict[activation_type]["fwd"](x) fwd_output = act_lu(x, activation_type)
return fwd_output, (x,) return fwd_output, (x,)
...@@ -92,7 +48,7 @@ def _activation_lu_bwd_rule(activation_type, ctx, g): ...@@ -92,7 +48,7 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx x, = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = activation_dict[activation_type]["bwd"](g, x) dx = dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape) dx = jnp.reshape(dx, x.shape)
return (dx,) return (dx,)
...@@ -106,7 +62,7 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, sca ...@@ -106,7 +62,7 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, sca
""" """
Activation Unit Activation Unit
""" """
transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1) transpose_indices = (1, 2, 0)
dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype) dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
...@@ -127,19 +83,15 @@ def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_us ...@@ -127,19 +83,15 @@ def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_us
return output return output
def _activation_lu_fp8_fwd_rule( def _activation_lu_fp8_fwd_rule(x,
x, dx_trans_no_use, # pylint: disable=unused-argument
dx_trans_no_use, # pylint: disable=unused-argument dbias_no_use, # pylint: disable=unused-argument
dbias_no_use, # pylint: disable=unused-argument amax,
amax, scale, scale_inv,
scale, fwd_dtype, bwd_dtype, # pylint: disable=unused-argument
scale_inv, activation_type):
fwd_dtype, activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument activation_type)
activation_type):
activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv,
fwd_dtype)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x, amax, scale, scale_inv) ctx = (x, amax, scale, scale_inv)
return activation_lu_out, ctx return activation_lu_out, ctx
...@@ -153,14 +105,14 @@ def _activation_lu_fp8_bwd_rule( ...@@ -153,14 +105,14 @@ def _activation_lu_fp8_bwd_rule(
g): g):
x, amax, scale, scale_inv = ctx x, amax, scale, scale_inv = ctx
activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"] if len(activation_type) > 1: #gated, no bias
if len(activation_type) > 1: #gated, no bias
dactivation_lu, dactivation_lu_trans, amax_out = \ dactivation_lu, dactivation_lu_trans, amax_out = \
activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1, activation_type)
dbias = jnp.empty(x.shape[-1], x.dtype) dbias = jnp.empty(x.shape[-1], x.dtype)
else: else: #not gated, with bias
dactivation_lu, dactivation_lu_trans, dbias, amax_out = \ dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype,
-1, -2, activation_type)
dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv) dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv) dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv) ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
...@@ -262,7 +214,6 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -262,7 +214,6 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_type, activation_type,
use_bias): use_bias):
is_gated = len(activation_type) > 1
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out) # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out) # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
...@@ -276,15 +227,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -276,15 +227,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0] assert kernel_1.shape[-1] == kernel_2.shape[0]
# Squeeze act axis
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
if not is_gated:
kernel_1 = jnp.squeeze(kernel_1, axis=-2)
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv) FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv) fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax) amax = FP8Helper.update_amax_history(amax)
...@@ -337,8 +282,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -337,8 +282,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
(x_contracting_dims, (0,)), (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias: if use_bias:
bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape bias_1_shape = bias_1.shape
dot_1_output += jnp.reshape(bias_1, bias_1_shape) bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
else:
bias_1_shape = None
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
...@@ -347,12 +295,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -347,12 +295,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_out_scale = scale[gemm2_x_idx] activation_lu_out_scale = scale[gemm2_x_idx]
activation_lu_out_scale_inv = scale_inv[gemm2_x_idx] activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"]
# (batch..., hidden_in) -> (batch..., hidden) # (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = \ casted_activation_lu_out, updated_activation_lu_amax = \
activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype) activation_lu_out_scale_inv, fwd_dtype, activation_type)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes( casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
casted_activation_lu_out, dot_2_input_axes) casted_activation_lu_out, dot_2_input_axes)
...@@ -370,15 +317,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -370,15 +317,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias: if use_bias:
bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape bias_2_shape = bias_2.shape
dot_2_output += jnp.reshape(bias_2, bias_2_shape) bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
else:
bias_2_shape = None
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32) x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32)
return dot_2_output, ctx return dot_2_output, ctx
...@@ -403,8 +353,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -403,8 +353,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
is_gated = len(activation_type) > 1
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1] grad_amax = amax[gemm2_grad_idx, 0:1]
...@@ -413,7 +361,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -413,7 +361,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# Since the sharding of outputs should be the same as dot_1's input # Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
if use_bias: if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \ casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
dbias_cast_transpose(grad, grad_amax, grad_scale, dbias_cast_transpose(grad, grad_amax, grad_scale,
...@@ -427,7 +374,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -427,7 +374,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
grad_scale_inv, bwd_dtype, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, static_axis_boundary=-1,
transpose_axis_boundary=-1) transpose_axis_boundary=-1)
dbias_2 = jnp.empty(bias_2_shape, grad.dtype) dbias_2 = None
casted_activation_lu_out_t = transpose(casted_activation_lu_out, casted_activation_lu_out_t = transpose(casted_activation_lu_out,
static_axis_boundary=-1, static_axis_boundary=-1,
...@@ -453,11 +400,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -453,11 +400,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale = scale[gemm1_grad_idx] dactivation_lu_scale = scale[gemm1_grad_idx]
dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx] dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]
dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"] if len(activation_type) > 1: # if gated
dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)
if is_gated:
if use_bias: if use_bias:
dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dbias_cast_transpose( dbias_cast_transpose(
dactivation_lu, dactivation_lu,
...@@ -470,19 +415,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -470,19 +415,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dbias_1 = jnp.reshape(dbias_1, bias_1_shape) dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else: else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose( dgated_act_lu_cast_transpose(
dgrad_2, dgrad_2,
dot_1_output, dot_1_output,
dactivation_lu_amax, dactivation_lu_amax,
dactivation_lu_scale, dactivation_lu_scale,
dactivation_lu_scale_inv, dactivation_lu_scale_inv,
bwd_dtype, bwd_dtype,
static_axis_boundary=-1) static_axis_boundary=-1,
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) activation_type=activation_type)
dbias_1 = None
else: else:
if use_bias: if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
dactivation_lu_cast_transpose( dact_lu_dbias_cast_transpose(
dgrad_2, dgrad_2,
dot_1_output, dot_1_output,
dactivation_lu_amax, dactivation_lu_amax,
...@@ -490,9 +436,11 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -490,9 +436,11 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv, dactivation_lu_scale_inv,
bwd_dtype, bwd_dtype,
static_axis_boundary=-1, static_axis_boundary=-1,
transpose_axis_boundary=-1) transpose_axis_boundary=-2,
activation_type=activation_type)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape) dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else: else:
dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
cast_transpose( cast_transpose(
dactivation_lu, dactivation_lu,
...@@ -501,28 +449,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -501,28 +449,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv, dactivation_lu_scale_inv,
bwd_dtype, bwd_dtype,
static_axis_boundary=-1, static_axis_boundary=-1,
transpose_axis_boundary=-1) transpose_axis_boundary=-2)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) dbias_1 = None
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx] gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
xt_batch_dims_2 = xt_batch_dims if not is_gated \ xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
else tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2), dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1
if not is_gated:
wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
# (batch..., hidden_out) x (hidden_in, hidden_out) x_contracting_dims = ((min(x_contracting_dims),) + tuple(
if is_gated: i + 1 for i in x_contracting_dims), (1,2))
x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2))
else:
x_contracting_dims = (x_contracting_dims, (1,))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv, dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
kernel_1_scale_inv, grad.dtype, x_contracting_dims, kernel_1_scale_inv, grad.dtype, x_contracting_dims,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment