Unverified Commit dac00019 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Unifying GeLU and GeGLU in LayerNorm MLP (#765)



* combined layernorm_geglu with layernorm_gelu into fused_layernorm
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fixes to pass all unit tests in test_custom_call_compute.py,
test_layer.py, and test_praxis_layer.py
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* cleaning and formatting
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* renaming based on reviewers suggestions
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* implemented partial fused layernorm
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* geglu + bias passed tests
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* added partial fused calculation for dbias_1
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* clean up
Co-authored-by: default avatarAlp Dener <adener@nvidia.com>
Signed-off-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Co-authored-by: default avatarAlp Dener <adener@nvidia.com>
parent 07bf4acf
......@@ -4,6 +4,7 @@
import functools
import operator
from typing import Callable, Sequence, Union
import jax
import jax.numpy as jnp
......@@ -22,8 +23,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
GEMM_CASES = [
(256, 256, 512),
......@@ -174,17 +174,32 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
@pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_geglu_fp8_mlp(self, m, n, k):
@pytest.mark.parametrize('activation_type', [('gelu', ),
('gelu', 'linear')])
@pytest.mark.parametrize('use_bias', [True, False])
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
""" N/a """
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
activations = ('gelu', 'linear')
subkeys = jax.random.split(key, 6)
activation_dict = {
('gelu', ): jax.nn.gelu
}
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
......@@ -192,14 +207,16 @@ class TestFP8Dot:
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = (x * y) * z
# out = ((x * y) + w) * z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
return jnp.mean(
fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
activation_type = activation_type, use_bias = use_bias))
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
......@@ -211,115 +228,7 @@ class TestFP8Dot:
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
x = jnp.split(linear_1_out, len(activations), axis=-2)
acts = []
for idx, act_fn in enumerate(activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
return output
def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
return jnp.mean(
ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7)))
ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv
pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max,
ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))
def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray:
......@@ -336,10 +245,20 @@ class TestFP8Dot:
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
if 'linear' in activation_type:
x = jnp.split(linear_1_out, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
else:
x = activation_dict[activation_type](linear_1_out)
x = jax.nn.gelu(linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
......@@ -348,15 +267,16 @@ class TestFP8Dot:
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
if use_bias:
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
return output
def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(
......@@ -373,12 +293,13 @@ class TestFP8Dot:
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
# Convert str to index as str is not a valid type for JAX JIT
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
......@@ -401,12 +322,14 @@ class TestFP8Dot:
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
if use_bias:
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs")
......
......@@ -529,10 +529,11 @@ void cast_transpose_dbias(const Tensor &input,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
// TODO
// CheckInputTensor(input, "cast_transpose_dbias_input");
// CheckOutputTensor(*cast_output, "cast_output");
// CheckOutputTensor(*transposed_output, "transposed_output");
// CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
......
......@@ -4334,6 +4334,231 @@ def dgelu_dbias_cast_transpose(
transpose_axis_boundary=transpose_axis_boundary)
class DBiasCastTransposePrimitive(BasePrimitive):
"""
DBias Cast Transpose Primitive
"""
name = "te_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary
impl_static_args = (4, 5, 6)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary, transpose_axis_boundary):
"""
te_dbias_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
gi_hidden_size = dz_aval.shape[-1]
t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
if dz_aval.shape[-2] == 2:
gi_hidden_size *= 2
dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes(
dz_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype)
)
wkspace_aval = dz_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out, t_out, dbias, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
te_dbias_cast_transpose_p outer abstract
"""
out, t_out, dbias, updated_amax_aval, _ = \
DBiasCastTransposePrimitive.abstract(*args, **kwargs)
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
transpose_axis_boundary):
"""
te_dbias_cast_transpose_p lowering rules
"""
dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
ir_hidden_szie = ir_dz_shape[-1]
if dz_aval.shape[-2] == 2:
batch_szie = reduce(operator.mul, ir_dz_shape[:-2])
ir_hidden_szie *= 2
else:
batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
contracted_dz_shape = (batch_szie, ir_hidden_szie)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary,
transpose_axis_boundary)
dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_szie)
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
]
operands = [dz, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
out = custom_caller(DBiasCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={1: 3})
return out
@staticmethod
def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary,
transpose_axis_boundary):
"""
to describe implementation
"""
assert DBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
transpose_axis_boundary):
"""
to describe batch rules for vmap
"""
del static_axis_boundary
_check_valid_batch_dims(batch_dims)
assert DBiasCastTransposePrimitive.outer_primitive is not None
dz, amax, scale, scale_inv = batched_args
dz_bdim, _, amax_bdim, _, _ = batch_dims
# Minus batch dim.
transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
return DBiasCastTransposePrimitive.outer_primitive.bind(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=dz_bdim,
transpose_axis_boundary=transpose_axis_boundary), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
@staticmethod
def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding,
amax_sharding)
def sharded_impl(dz, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DBiasCastTransposePrimitive)
def dbias_cast_transpose(
dz: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose dbias partial fusion wrapper
Return FP8(inputs), dbias
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
return DBiasCastTransposePrimitive.outer_primitive.bind(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
class GatedGeluFp8Primitive(BasePrimitive):
"""
Gated Gelu FP8 Primitive
......
......@@ -29,6 +29,7 @@ pybind11::dict Registrations() {
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
......@@ -66,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
......
......@@ -301,6 +301,69 @@ void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
dbias_tensor.data(), workspace.data(), stream);
}
// HERE
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *output_trans = buffers[5];
auto *dbias = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), dbias_tensor.data(),
workspace.data(), stream);
}
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2};
......
......@@ -152,6 +152,12 @@ pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -22,8 +22,7 @@ from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -944,35 +943,22 @@ class LayerNormMLP(TransformerEngineBase):
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
def is_geglu(acts):
geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
def is_gelu(acts):
geglu_act_pool = [('gelu',)]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
# Make sure each tuple is sorted in alphabet order
gated_act_pool = [('gelu', 'linear')]
#('linear', 'silu')] coming
act_pool = [('gelu',)]
#('silu',)] coming
normalize_acts = []
for act in self.activations:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
normalize_acts = tuple(sorted(normalize_acts))
is_gated = normalize_acts in gated_act_pool
is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
self.intermediate_dropout_rate < 1e-3
# LayerNorm
if self.enable_layernorm:
......@@ -1045,38 +1031,26 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_geglu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
bias_1_shape = intermediate_dim if self.use_bias else 0
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,) if self.use_bias else (0,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y,
out = fused_layernorm_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
......@@ -1087,9 +1061,10 @@ class LayerNormMLP(TransformerEngineBase):
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type = normalize_acts,
use_bias = self.use_bias)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(0)
......@@ -1142,31 +1117,29 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
wi_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
bias_1 = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias',
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias = bias.astype(self.dtype)
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
bias_1 = bias_1.astype(self.dtype)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_geglu(self.activations):
z = geglu(x)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
if is_act_implemented:
z = activation_lu(x, normalize_acts)
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = functools.reduce(operator.mul, activations)
if not is_gated:
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate,
......@@ -1207,14 +1180,14 @@ class LayerNormMLP(TransformerEngineBase):
out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
wo_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
bias_2 = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias',
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
......
......@@ -3,15 +3,15 @@
# See LICENSE for license information.
"""JAX MLP modules"""
from typing import List, Tuple
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from .cpp_extensions import cast_fp8, transpose, cast_transpose
from .cpp_extensions import gelu as te_gelu
from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
......@@ -23,369 +23,56 @@ from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def gelu(x: jnp.ndarray):
"""
Gelu
"""
output = _gelu(x)
return output
@partial(jax.custom_vjp)
def _gelu(x: jnp.ndarray):
geglu_output, _ = _gelu_fwd_rule(x)
return geglu_output
def _gelu_fwd_rule(x):
geglu_output = te_gelu(x)
return geglu_output, (x,)
def _gelu_bwd_rule(ctx, g):
x, = ctx
assert x.dtype == g.dtype
dx = dgelu(g, x)
dx = jnp.reshape(dx, x.shape)
return (dx,)
activation_dict = {
('gelu',): {'fwd': gelu,
"bwd": dgelu},
('gelu', 'linear'): {'fwd': gated_gelu,
'bwd': dgated_gelu}
}
activation_fp8_dict = {
('gelu',): {'fwd': gelu_fp8,
'bwd': dgelu_dbias_cast_transpose},
('gelu', 'linear'): {'fwd': gated_gelu_fp8,
'bwd': dgated_gelu_cast_transpose}
}
_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule)
def geglu(x: jnp.ndarray):
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
"""
Gated gelu
Activation Unit
"""
assert x.shape[-2] == 2 # Linear + GeLU
output = _geglu(x)
if len(activation_type) > 1:
assert x.shape[-2] == 2 # Linear + GeLU
output = _activation_lu(x, activation_type)
return output
@partial(jax.custom_vjp)
def _geglu(x: jnp.ndarray):
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
geglu_output, _ = _geglu_fwd_rule(x)
_output, _ = _activation_lu_fwd_rule(x, activation_type)
return geglu_output
return _output
def _geglu_fwd_rule(x):
geglu_output = gated_gelu(x)
return geglu_output, (x,)
def _activation_lu_fwd_rule(x, activation_type):
fwd_output = activation_dict[activation_type]["fwd"](x)
return fwd_output, (x,)
def _geglu_bwd_rule(ctx, g):
def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx
assert x.dtype == g.dtype
dx = dgated_gelu(g, x)
dx = activation_dict[activation_type]["bwd"](g, x)
dx = jnp.reshape(dx, x.shape)
return (dx,)
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
fp8_gemm_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
"""
Layernorm + GEMM1 + GeGLU + GEMM2
"""
assert len(kernels) == 2
assert fp8_gemm_pkg.num_of_gemm == len(kernels)
kernel_1 = kernels[0]
kernel_2 = kernels[1]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name)
return output
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str):
output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
return output
def _layernorm_geglu_fp8_mlp_fwd_rule(
x,
gamma,
beta,
kernel_1,
kernel_2,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name):
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == 2
assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,)
xt_batch_dims = tuple(range(1, x.ndim))
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm1_x_idx, 0:1]
x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
gamma,
beta,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
mu = None
assert x.shape == ln_out.shape
kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
# Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
# unnecessary copy to break FP8 GEMM pattern matching.
casted_kernel_1, updated_kernel_1_amax = \
cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
# (batch..., hidden_in) x (hidden_in, 2, hidden_out)
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
geglu_out_amax = amax[gemm2_x_idx, 0:1]
geglu_out_scale = scale[gemm2_x_idx]
geglu_out_scale_inv = scale_inv[gemm2_x_idx]
# (batch..., hidden_in) -> (batch..., hidden)
casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax,
geglu_out_scale, geglu_out_scale_inv,
fwd_dtype)
casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
# Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
# unnecessary copy to break FP8 GEMM pattern matching.
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)
return dot_2_output, ctx
def _layernorm_geglu_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims = ctx
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
casted_geglu_out_t = transpose(casted_geglu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
# (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgeglu_amax = amax[gemm1_grad_idx, 0:1]
dgeglu_scale = scale[gemm1_grad_idx]
dgeglu_scale_inv = scale_inv[gemm1_grad_idx]
casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose(
dgrad_2,
dot_1_output,
dgeglu_amax,
dgeglu_scale,
dgeglu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (2, hidden, batch...)
xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims)
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, \
fp8_max, amax, scale, scale_inv
_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
_layernorm_geglu_fp8_mlp_bwd_rule)
def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
......@@ -398,9 +85,11 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
ffn2_ckpt_name: str = 'ffn2',
activation_type: Sequence[Union[str, Callable]] = ('gelu',),
use_bias: bool = True) -> jnp.ndarray:
"""
Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias
Layernorm + GEMM1 + bias + activation + GEMM2 + bias
"""
assert len(kernels) == 2
......@@ -424,32 +113,36 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name)
ffn2_ckpt_name, activation_type, use_bias)
return output
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str):
output, _ = _layernorm_gelu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
fp8_max, amax, scale, scale_inv, fwd_dtype,
bwd_dtype, layernorm_type, zero_centered_gamma,
epsilon, layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
ffn1_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
bias_2, fp8_max, amax, scale, scale_inv,
fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
return output
def _layernorm_gelu_fp8_mlp_fwd_rule(
def _fused_layernorm_fp8_mlp_fwd_rule(
x,
gamma,
beta,
......@@ -470,13 +163,16 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name):
ffn2_ckpt_name,
activation_type,
use_bias):
is_gated = len(activation_type) > 1
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == 1
assert kernel_1.shape[-2] == len(activation_type)
assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,)
......@@ -487,7 +183,8 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
# Squeeze act axis
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
kernel_1 = jnp.squeeze(kernel_1, axis=-2)
if not is_gated:
kernel_1 = jnp.squeeze(kernel_1, axis=-2)
amax = FP8Helper.update_amax_history(amax)
......@@ -539,22 +236,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
dot_1_output += jnp.reshape(bias_1, bias_1_shape)
if use_bias:
bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
dot_1_output += jnp.reshape(bias_1, bias_1_shape)
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
gelu_out_amax = amax[gemm2_x_idx, 0:1]
gelu_out_scale = scale[gemm2_x_idx]
gelu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
activation_lu_out_scale = scale[gemm2_x_idx]
activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_fp8 = activation_fp8_dict[activation_type]["fwd"]
# (batch..., hidden_in) -> (batch..., hidden)
casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale,
gelu_out_scale_inv, fwd_dtype)
casted_activation_lu_out, updated_activation_lu_amax = activation_lu_fp8(dot_1_output,
activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype)
casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
......@@ -563,23 +264,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv,
dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
activation_lu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
dot_2_output += jnp.reshape(bias_2, bias_2_shape)
if use_bias:
bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
dot_2_output += jnp.reshape(bias_2, bias_2_shape)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax,
updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims,
bias_1.shape, bias_2.shape)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape)
return dot_2_output, ctx
def _layernorm_gelu_fp8_mlp_bwd_rule(
def _fused_layernorm_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
......@@ -590,13 +294,17 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
activation_type,
use_bias,
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx
is_gated = len(activation_type) > 1
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
......@@ -606,21 +314,29 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
casted_gelu_out_t = transpose(casted_gelu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
dbias_cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
else:
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = jnp.empty(bias_2_shape, grad.dtype)
dbias_2 = jnp.sum(grad, axis=(i for i in range(grad.ndim - 1)))
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
casted_activation_lu_out_t = transpose(casted_activation_lu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
# (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims),
wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., hidden_out) x (hidden_in, hidden_out)
......@@ -633,36 +349,85 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgelu_amax = amax[gemm1_grad_idx, 0:1]
dgelu_scale = scale[gemm1_grad_idx]
dgelu_scale_inv = scale_inv[gemm1_grad_idx]
casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose(
dgrad_2,
dot_1_output,
dgelu_amax,
dgelu_scale,
dgelu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
dactivation_lu_scale = scale[gemm1_grad_idx]
dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]
dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"]
dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)
if is_gated:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dbias_cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
else:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype,
(xt_batch_dims, xt_batch_dims),
xt_batch_dims_2 = xt_batch_dims if not is_gated \
else tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype,
(xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1
wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
if not is_gated:
wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
# (batch..., hidden_out) x (hidden_in, hidden_out)
if is_gated:
x_contracting_dims = ((min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims), (1,2))
else:
x_contracting_dims = (x_contracting_dims, (1,))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1,
dactivation_lu_scale_inv, kernel_1_scale_inv,
grad.dtype, x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
......@@ -683,15 +448,15 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
fp8_max, amax, scale, scale_inv
_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule)
_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
_fused_layernorm_fp8_mlp_bwd_rule)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment