Unverified commit dac00019, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Unifying GeLU and GeGLU in LayerNorm MLP (#765)



* combined layernorm_geglu with layernorm_gelu into fused_layernorm
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes to pass all unit tests in test_custom_call_compute.py,
test_layer.py, and test_praxis_layer.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleaning and formatting
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* renaming based on reviewers' suggestions
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* implemented partial fused layernorm
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* geglu + bias passed tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added partial fused calculation for dbias_1
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* clean up
Co-authored-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Co-authored-by: Alp Dener <adener@nvidia.com>
parent 07bf4acf
......@@ -4,6 +4,7 @@
import functools
import operator
from typing import Callable, Sequence, Union
import jax
import jax.numpy as jnp
......@@ -22,8 +23,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
GEMM_CASES = [
(256, 256, 512),
......@@ -174,17 +174,32 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
@pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_geglu_fp8_mlp(self, m, n, k):
@pytest.mark.parametrize('activation_type', [('gelu', ),
('gelu', 'linear')])
@pytest.mark.parametrize('use_bias', [True, False])
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
""" N/a """
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
activations = ('gelu', 'linear')
subkeys = jax.random.split(key, 6)
activation_dict = {
('gelu', ): jax.nn.gelu
}
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
......@@ -192,14 +207,16 @@ class TestFP8Dot:
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = (x * y) * z
# out = ((x * y) + w) * z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
return jnp.mean(
fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
activation_type = activation_type, use_bias = use_bias))
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
......@@ -211,115 +228,7 @@ class TestFP8Dot:
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
x = jnp.split(linear_1_out, len(activations), axis=-2)
acts = []
for idx, act_fn in enumerate(activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
return output
def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
return jnp.mean(
ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7)))
ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv
pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max,
ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))
def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray:
......@@ -336,10 +245,20 @@ class TestFP8Dot:
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
if 'linear' in activation_type:
x = jnp.split(linear_1_out, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
else:
x = activation_dict[activation_type](linear_1_out)
x = jax.nn.gelu(linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
......@@ -348,15 +267,16 @@ class TestFP8Dot:
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
if use_bias:
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
return output
def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(
......@@ -373,12 +293,13 @@ class TestFP8Dot:
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
# Convert str to index as str is not a valid type for JAX JIT
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
......@@ -401,12 +322,14 @@ class TestFP8Dot:
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
if use_bias:
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs")
......
......@@ -529,10 +529,11 @@ void cast_transpose_dbias(const Tensor &input,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
// TODO
// CheckInputTensor(input, "cast_transpose_dbias_input");
// CheckOutputTensor(*cast_output, "cast_output");
// CheckOutputTensor(*transposed_output, "transposed_output");
// CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
......
......@@ -4334,6 +4334,231 @@ def dgelu_dbias_cast_transpose(
transpose_axis_boundary=transpose_axis_boundary)
class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive

    Wraps the ``te_dbias_cast_transpose`` custom call. The primitive takes four
    operands, in this order: ``dz``, ``amax``, ``scale``, ``scale_inv``; and yields
    the cast of ``dz``, its transposed cast, ``dbias`` and the updated ``amax``
    (the inner primitive additionally yields a workspace buffer).
    """
    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose_p abstract

        Returns the abstract values of (out, t_out, dbias, updated_amax, workspace).
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        gi_hidden_size = dz_aval.shape[-1]
        t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        # A second-to-last axis of size 2 is folded into the hidden size
        # (presumably the gated-activation layout -- confirm against callers).
        if dz_aval.shape[-2] == 2:
            gi_hidden_size *= 2
        dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        # The extension reports the (shape, dtype) of the workspace it needs.
        wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype)
        )
        wkspace_aval = dz_aval.update(shape=wkspace_info[0],
                                      dtype=te_dtype_to_jax_dtype(wkspace_info[1]))

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract

        Same as ``abstract`` but without the trailing workspace value.
        """
        out, t_out, dbias, updated_amax_aval, _ = \
            DBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_dbias_cast_transpose_p lowering rules
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        ir_hidden_size = ir_dz_shape[-1]
        # Collapse dz to 2D (batch, hidden); mirrors the size-2 axis handling in `abstract`.
        if dz_aval.shape[-2] == 2:
            batch_size = reduce(operator.mul, ir_dz_shape[:-2])
            ir_hidden_size *= 2
        else:
            batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        contracted_dz_shape = (batch_size, ir_hidden_size)

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with amax's shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary,
                                                  transpose_axis_boundary)
        dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size)

        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))

        # Operand 1 (amax) is aliased to output 3 (updated amax).
        out = custom_caller(DBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 3})

        return out

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary):
        """
        Bind the inner primitive and drop its workspace output.
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        # One batch dim per operand: (dz, amax, scale, scale_inv).
        dz_bdim, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return DBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=dz_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        """
        Derive the output shardings from the operand shardings.

        ``arg_infos`` follows the operand order (dz, amax, scale, scale_inv).
        """
        del out_dtype, result_infos
        # The data outputs follow dz (operand 0); amax follows operand 1.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        """
        Build the per-shard implementation together with its operand/output shardings.

        ``arg_infos`` follows the operand order (dz, amax, scale, scale_inv).
        """
        del result_infos
        # The data outputs follow dz (operand 0); amax follows operand 1.
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_sharding,
                         amax_sharding)

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            # Local dbias/amax are combined across shards by the reduction helpers.
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DBiasCastTransposePrimitive)
def dbias_cast_transpose(
        dz: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
        out_dtype: TEDType,
        static_axis_boundary: int,
        transpose_axis_boundary: int = -1
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper

    Binds DBiasCastTransposePrimitive's outer primitive.
    Return FP8(inputs), transposed FP8(inputs), dbias, updated amax
    """
    # Any negative boundary is normalized to -1.
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
class GatedGeluFp8Primitive(BasePrimitive):
"""
Gated Gelu FP8 Primitive
......
......@@ -29,6 +29,7 @@ pybind11::dict Registrations() {
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
......@@ -66,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
......
......@@ -301,6 +301,69 @@ void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
dbias_tensor.data(), workspace.data(), stream);
}
// DBias + cast + transpose (partial fusion)
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
    // Describe every tensor by shape and dtype only; no data pointer is supplied.
    const std::vector<size_t> in_dims{batch_size, hidden_size};
    const std::vector<size_t> cast_dims{batch_size, hidden_size};
    const std::vector<size_t> cast_t_dims{hidden_size, batch_size};
    const std::vector<size_t> dbias_dims{hidden_size};

    TensorWrapper in_desc(nullptr, in_dims, in_dtype);
    TensorWrapper cast_desc(nullptr, cast_dims, out_dtype);
    TensorWrapper cast_t_desc(nullptr, cast_t_dims, out_dtype);
    TensorWrapper dbias_desc(nullptr, dbias_dims, in_dtype);

    // Default-constructed workspace; its shape and dtype are read back after the call.
    // NOTE(review): presumably the kernel only fills in workspace requirements when
    // the workspace is empty -- confirm against nvte_cast_transpose_dbias.
    TensorWrapper query_workspace;

    nvte_cast_transpose_dbias(in_desc.data(), cast_desc.data(), cast_t_desc.data(),
                              dbias_desc.data(), query_workspace.data(), nullptr);

    auto wk_dims = MakeShapeVector(query_workspace.shape());
    return pybind11::make_tuple(std::make_pair(wk_dims, query_workspace.dtype()));
}
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                        size_t opaque_len) {
    // Buffer layout: four operands followed by five results.
    void *dz_ptr = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    void *cast_ptr = buffers[4];
    void *cast_t_ptr = buffers[5];
    void *dbias_ptr = buffers[6];
    float *amax_out = reinterpret_cast<float *>(buffers[7]);
    void *wk_ptr = buffers[8];

    const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);

    // The input and output amax buffers are expected to be the same memory.
    assert(amax == amax_out);

    // Scaling metadata is only forwarded when the output dtype is FP8.
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }

    const auto rows = desc.shape.dims[0];
    const auto cols = desc.shape.dims[1];
    const std::vector<size_t> in_dims{rows, cols};
    const std::vector<size_t> cast_dims{rows, cols};
    const std::vector<size_t> cast_t_dims{cols, rows};
    const std::vector<size_t> dbias_dims{cols};

    TensorWrapper in_tensor(dz_ptr, in_dims, desc.in_dtype);
    TensorWrapper cast_tensor(cast_ptr, cast_dims, desc.out_dtype, amax_out, scale, scale_inv);
    TensorWrapper cast_t_tensor(cast_t_ptr, cast_t_dims, desc.out_dtype, amax_out, scale,
                                scale_inv);
    TensorWrapper dbias_tensor(dbias_ptr, dbias_dims, desc.in_dtype);
    TensorWrapper wk_tensor(wk_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

    nvte_cast_transpose_dbias(in_tensor.data(), cast_tensor.data(), cast_t_tensor.data(),
                              dbias_tensor.data(), wk_tensor.data(), stream);
}
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2};
......
......@@ -152,6 +152,12 @@ pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -22,8 +22,7 @@ from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -944,35 +943,22 @@ class LayerNormMLP(TransformerEngineBase):
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
def is_geglu(acts):
geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
def is_gelu(acts):
geglu_act_pool = [('gelu',)]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
# Make sure each tuple is sorted in alphabet order
gated_act_pool = [('gelu', 'linear')]
#('linear', 'silu')] coming
act_pool = [('gelu',)]
#('silu',)] coming
normalize_acts = []
for act in self.activations:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
normalize_acts = tuple(sorted(normalize_acts))
is_gated = normalize_acts in gated_act_pool
is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
self.intermediate_dropout_rate < 1e-3
# LayerNorm
if self.enable_layernorm:
......@@ -1045,38 +1031,26 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_geglu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
bias_1_shape = intermediate_dim if self.use_bias else 0
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,) if self.use_bias else (0,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y,
out = fused_layernorm_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
......@@ -1087,9 +1061,10 @@ class LayerNormMLP(TransformerEngineBase):
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type = normalize_acts,
use_bias = self.use_bias)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(0)
......@@ -1142,31 +1117,29 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
wi_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
bias_1 = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias',
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias = bias.astype(self.dtype)
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
bias_1 = bias_1.astype(self.dtype)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_geglu(self.activations):
z = geglu(x)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
if is_act_implemented:
z = activation_lu(x, normalize_acts)
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = functools.reduce(operator.mul, activations)
if not is_gated:
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate,
......@@ -1207,14 +1180,14 @@ class LayerNormMLP(TransformerEngineBase):
out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
wo_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
bias_2 = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias',
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment