Unverified Commit dac00019 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Unifying GeLU and GeGLU in LayerNorm MLP (#765)



* combined layernorm_geglu with layernorm_gelu into fused_layernorm
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fixes to pass all unit tests in test_custom_call_compute.py,
test_layer.py, and test_praxis_layer.py
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* cleaning and formatting
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* renaming based on reviewers suggestions
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* implemented partial fused layernorm
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* geglu + bias passed tests
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* added partial fused calculation for dbias_1
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* clean up
Co-authored-by: default avatarAlp Dener <adener@nvidia.com>
Signed-off-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Co-authored-by: default avatarAlp Dener <adener@nvidia.com>
parent 07bf4acf
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import functools import functools
import operator import operator
from typing import Callable, Sequence, Union
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -22,8 +23,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti ...@@ -22,8 +23,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
GEMM_CASES = [ GEMM_CASES = [
(256, 256, 512), (256, 256, 512),
...@@ -174,17 +174,32 @@ class TestFP8Dot: ...@@ -174,17 +174,32 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), @pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]) (16384, 1024, 1024)])
def test_grad_ln_geglu_fp8_mlp(self, m, n, k): @pytest.mark.parametrize('activation_type', [('gelu', ),
('gelu', 'linear')])
@pytest.mark.parametrize('use_bias', [True, False])
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
""" N/a """
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4) subkeys = jax.random.split(key, 6)
activations = ('gelu', 'linear')
activation_dict = {
('gelu', ): jax.nn.gelu
}
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16) k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16) s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros( init_fp8_metas_amax = jnp.zeros(
...@@ -192,14 +207,16 @@ class TestFP8Dot: ...@@ -192,14 +207,16 @@ class TestFP8Dot:
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv): fp8_metas_scale_inv):
# x is input tensor, matrix 2d # x is input tensor, matrix 2d
# y, z are weights, matrix 2d # y, z are weights, matrix 2d
# out = (x * y) * z # out = ((x * y) + w) * z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) fp8_metas_scale_inv)
return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) return jnp.mean(
fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm",
activation_type = activation_type, use_bias = use_bias))
def _convert_to_activation_function(fn_or_string): def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function.""" """Convert a string to an activation function."""
...@@ -211,115 +228,7 @@ class TestFP8Dot: ...@@ -211,115 +228,7 @@ class TestFP8Dot:
return fn_or_string return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
x = jnp.split(linear_1_out, len(activations), axis=-2)
acts = []
for idx, act_fn in enumerate(activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
return output
def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
return jnp.mean(
ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7)))
ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv
pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max,
ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))
def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray: scale_inv: jnp.ndarray) -> jnp.ndarray:
...@@ -336,10 +245,20 @@ class TestFP8Dot: ...@@ -336,10 +245,20 @@ class TestFP8Dot:
scale_inv[:FP8Helper.NUM_META_PER_GEMM]) scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,))) linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape) linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = jax.nn.gelu(linear_1_out) if 'linear' in activation_type:
x = jnp.split(linear_1_out, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
else:
x = activation_dict[activation_type](linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
...@@ -348,6 +267,7 @@ class TestFP8Dot: ...@@ -348,6 +267,7 @@ class TestFP8Dot:
scale_inv[FP8Helper.NUM_META_PER_GEMM:]) scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,))) output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
if use_bias:
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape) output += jnp.reshape(bias_2, bias_2_shape)
...@@ -356,7 +276,7 @@ class TestFP8Dot: ...@@ -356,7 +276,7 @@ class TestFP8Dot:
def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv): fp8_metas_scale_inv):
return jnp.mean( return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)) fp8_metas_scale_inv))
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
...@@ -373,6 +293,7 @@ class TestFP8Dot: ...@@ -373,6 +293,7 @@ class TestFP8Dot:
pri_fp8_metas_scale = init_fp8_metas_scale pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
# Convert str to index as str is not a valid type for JAX JIT
for _ in range(3): for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad, ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
...@@ -401,6 +322,7 @@ class TestFP8Dot: ...@@ -401,6 +322,7 @@ class TestFP8Dot:
assert_allclose(jnp.asarray(primitive_s_grad, np.float32), assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE)
if use_bias:
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32), jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
...@@ -409,6 +331,7 @@ class TestFP8Dot: ...@@ -409,6 +331,7 @@ class TestFP8Dot:
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape): def random_inputs_fixture(shape):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
......
...@@ -529,10 +529,11 @@ void cast_transpose_dbias(const Tensor &input, ...@@ -529,10 +529,11 @@ void cast_transpose_dbias(const Tensor &input,
Tensor *dbias, Tensor *dbias,
Tensor *workspace, Tensor *workspace,
cudaStream_t stream) { cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_input"); // TODO
CheckOutputTensor(*cast_output, "cast_output"); // CheckInputTensor(input, "cast_transpose_dbias_input");
CheckOutputTensor(*transposed_output, "transposed_output"); // CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*dbias, "dbias"); // CheckOutputTensor(*transposed_output, "transposed_output");
// CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
......
...@@ -4334,6 +4334,231 @@ def dgelu_dbias_cast_transpose( ...@@ -4334,6 +4334,231 @@ def dgelu_dbias_cast_transpose(
transpose_axis_boundary=transpose_axis_boundary) transpose_axis_boundary=transpose_axis_boundary)
class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive.

    Fuses three operations on an incoming gradient ``dz`` into one TE kernel
    (``te_dbias_cast_transpose``): an FP8 cast of ``dz``, an FP8 cast of the
    transpose of ``dz``, and ``dbias`` (the reduction of ``dz`` used as the
    bias gradient). The amax entry is updated in place as a side effect of
    the FP8 casts.
    """
    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose_p abstract.

        Returns avals for (cast_out, transposed_out, dbias, updated_amax,
        workspace); the workspace aval is consumed only by the inner primitive.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        gi_hidden_size = dz_aval.shape[-1]
        t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # A gated activation stacks two projections along axis -2; fold that
        # factor into the hidden size so dbias covers both projections.
        if dz_aval.shape[-2] == 2:
            gi_hidden_size *= 2
        dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        # Ask the TE backend how much scratch space the fused kernel needs.
        wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype)
        )
        wkspace_aval = dz_aval.update(shape=wkspace_info[0],
                                      dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract (drops the workspace aval).
        """
        out, t_out, dbias, updated_amax_aval, _ = \
            DBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_dbias_cast_transpose_p lowering rules.
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        # The TE kernel sees a flattened 2D (batch, hidden) view; a
        # gated-activation axis of length 2 at position -2 folds into hidden.
        ir_hidden_size = ir_dz_shape[-1]
        if dz_aval.shape[-2] == 2:
            batch_size = reduce(operator.mul, ir_dz_shape[:-2])
            ir_hidden_size *= 2
        else:
            batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        contracted_dz_shape = (batch_size, ir_hidden_size)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary,
                                                  transpose_axis_boundary)
        dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size)
        wkspace_aval = ctx.avals_out[-1]
        out_types = [
            ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
            # dbias stays in the input dtype; only the casts produce out_dtype.
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
        out = custom_caller(DBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 3})    # amax updated in place
        return out

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary):
        """
        Bind the inner primitive and drop the workspace output.
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        Batch (vmap) rule: treat the vmapped axis as a static axis.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        # batch_dims has exactly one entry per operand (4 here). The original
        # unpacked five entries (copy-paste from the 5-operand dgelu_dbias
        # primitive), which raises ValueError as soon as this rule is traced.
        dz_bdim, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return DBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=dz_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        del out_dtype, result_infos
        # Operands are (dz, amax, scale, scale_inv): dz is arg_infos[0] and
        # amax is arg_infos[1]. The original indexed one slot too far — a
        # leftover from the dgelu_dbias variant whose operand list carries an
        # extra tensor before amax.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        del result_infos
        # Same operand-index fix as infer_sharding_from_operands above.
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_sharding,
                         amax_sharding)

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            # Each shard computes a partial dbias sum and a local amax;
            # combine them across data-parallel/FSDP axes.
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
# Register the abstract/lowering/batcher/partitioning rules defined on the
# class with JAX, creating the inner and outer primitives.
register_primitive(DBiasCastTransposePrimitive)
def dbias_cast_transpose(
        dz: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
        out_dtype: TEDType,
        static_axis_boundary: int,
        transpose_axis_boundary: int = -1
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Cast-transpose + dbias partial-fusion wrapper.

    Returns a 4-tuple (the original annotation claimed 3 values, but the
    primitive's outer_abstract yields four):
        FP8(dz)             - cast of dz in ``out_dtype``
        FP8(dz^T)           - cast of the transposed dz in ``out_dtype``
        dbias               - reduction of dz (bias gradient), in dz's dtype
        updated_amax        - amax updated by the FP8 casts
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
class GatedGeluFp8Primitive(BasePrimitive): class GatedGeluFp8Primitive(BasePrimitive):
""" """
Gated Gelu FP8 Primitive Gated Gelu FP8 Primitive
......
...@@ -29,6 +29,7 @@ pybind11::dict Registrations() { ...@@ -29,6 +29,7 @@ pybind11::dict Registrations() {
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8); dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu); dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose); dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
...@@ -66,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -66,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
......
...@@ -301,6 +301,69 @@ void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op ...@@ -301,6 +301,69 @@ void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op
dbias_tensor.data(), workspace.data(), stream); dbias_tensor.data(), workspace.data(), stream);
} }
// Workspace-size query and XLA custom-call entry point for the fused
// cast + transpose + dbias kernel (te_dbias_cast_transpose).
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
  // Build nullptr-backed tensor descriptors: with no data attached, the nvte
  // call below only performs the workspace-size query instead of launching.
  const std::vector<size_t> in_shape{batch_size, hidden_size};
  const std::vector<size_t> out_shape{batch_size, hidden_size};
  const std::vector<size_t> trans_shape{hidden_size, batch_size};
  const std::vector<size_t> bias_shape{hidden_size};

  auto in_tensor = TensorWrapper(nullptr, in_shape, in_dtype);
  auto out_tensor = TensorWrapper(nullptr, out_shape, out_dtype);
  auto trans_tensor = TensorWrapper(nullptr, trans_shape, out_dtype);
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, in_dtype);

  // The query fills in the required workspace shape and dtype.
  TensorWrapper query_workspace;
  nvte_cast_transpose_dbias(in_tensor.data(), out_tensor.data(), trans_tensor.data(),
                            bias_tensor.data(), query_workspace.data(), nullptr);

  auto workspace_shape = MakeShapeVector(query_workspace.shape());
  return pybind11::make_tuple(std::make_pair(workspace_shape, query_workspace.dtype()));
}
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                        size_t opaque_len) {
  // XLA custom-call buffer table:
  //   inputs : [0] dz, [1] amax, [2] scale, [3] scale_inv
  //   outputs: [4] cast, [5] cast-transpose, [6] dbias, [7] amax_out, [8] workspace
  auto *dz = buffers[0];
  auto *amax = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *cast_out = buffers[4];
  auto *cast_trans_out = buffers[5];
  auto *dbias_out = buffers[6];
  auto *amax_out = reinterpret_cast<float *>(buffers[7]);
  auto *workspace_ptr = buffers[8];

  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
  // The lowering rule aliases amax to amax_out so the update happens in place.
  assert(amax == amax_out);
  if (!use_fp8(desc.out_dtype)) {
    // Non-FP8 outputs carry no scaling metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  const size_t rows = desc.shape.dims[0];
  const size_t cols = desc.shape.dims[1];
  auto dz_tensor = TensorWrapper(dz, std::vector<size_t>{rows, cols}, desc.in_dtype);
  auto cast_tensor = TensorWrapper(cast_out, std::vector<size_t>{rows, cols}, desc.out_dtype,
                                   amax_out, scale, scale_inv);
  auto cast_trans_tensor = TensorWrapper(cast_trans_out, std::vector<size_t>{cols, rows},
                                         desc.out_dtype, amax_out, scale, scale_inv);
  // dbias is reduced over the batch dimension, leaving one value per column.
  auto dbias_tensor = TensorWrapper(dbias_out, std::vector<size_t>{cols}, desc.in_dtype);
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

  nvte_cast_transpose_dbias(dz_tensor.data(), cast_tensor.data(), cast_trans_tensor.data(),
                            dbias_tensor.data(), workspace.data(), stream);
}
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) { cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2}; auto input_shape = std::vector<size_t>{m, n * 2};
......
...@@ -152,6 +152,12 @@ pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size ...@@ -152,6 +152,12 @@ pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
// Queries the scratch-workspace (shape, dtype) required by the fused
// cast + transpose + dbias kernel for the given problem size and dtypes.
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);

// XLA custom-call entry point for the fused cast + transpose + dbias kernel.
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
...@@ -22,8 +22,7 @@ from ..dot import type_safe_dot_general ...@@ -22,8 +22,7 @@ from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu from ..mlp import fused_layernorm_fp8_mlp, activation_lu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..softmax import is_softmax_kernel_available from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
...@@ -944,35 +943,22 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -944,35 +943,22 @@ class LayerNormMLP(TransformerEngineBase):
fuse_layernorm = FP8Helper.is_fp8_enabled( fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm ) and not self.return_layernorm_output and self.enable_layernorm
def is_geglu(acts): # Make sure each tuple is sorted in alphabet order
geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')] gated_act_pool = [('gelu', 'linear')]
#('linear', 'silu')] coming
normalize_acts = [] act_pool = [('gelu',)]
for act in acts: #('silu',)] coming
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
def is_gelu(acts):
geglu_act_pool = [('gelu',)]
normalize_acts = [] normalize_acts = []
for act in acts: for act in self.activations:
if not isinstance(act, str): if not isinstance(act, str):
return False return False
normalize_acts.append(act.lower()) normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool normalize_acts = tuple(sorted(normalize_acts))
is_gated = normalize_acts in gated_act_pool
use_fused_ln_geglu_mlp = fuse_layernorm \ is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
use_fused_ln_gelu_mlp = fuse_layernorm \ use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
and self.use_bias and is_gelu(self.activations) \ self.intermediate_dropout_rate < 1e-3
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
...@@ -1045,38 +1031,26 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1045,38 +1031,26 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name = 'ffn1' ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2' ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp: if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis == -1 at this moment
out = layernorm_geglu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
assert self.axis == -1 # Only support axis == -1 at this moment assert self.axis == -1 # Only support axis == -1 at this moment
bias_1_shape = intermediate_dim if self.use_bias else 0
bias_1 = nn_partitioning.param_with_axes('wi_bias', bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init, self.bias_init,
intermediate_dim, bias_1_shape,
jnp.float32, jnp.float32,
axes=self.bias_axes_1) axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype) bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,) if self.use_bias else (0,)
bias_2 = nn_partitioning.param_with_axes('wo_bias', bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,), self.bias_init,
bias_2_shape,
jnp.float32, jnp.float32,
axes=self.bias_axes_2) axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype) bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y, out = fused_layernorm_fp8_mlp(y,
scale, scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2], ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package, fp8_meta_package,
...@@ -1087,9 +1061,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1087,9 +1061,10 @@ class LayerNormMLP(TransformerEngineBase):
dot_1_input_axes=self.dot_1_input_axes, dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes, dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name, ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name) ffn2_ckpt_name=ffn2_ckpt_name,
activation_type = normalize_acts,
use_bias = self.use_bias)
else: # not use_fused_ln_geglu_mlp else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1 # DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \ gemm1_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(0) else fp8_meta_package.get_package_by_gemm_idx(0)
...@@ -1142,31 +1117,29 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1142,31 +1117,29 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel, x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
wi_lora_b_kernel, self.low_rank_adaptation_alpha) wi_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None bias_1 = None
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias', bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init, self.bias_init,
intermediate_dim, intermediate_dim,
jnp.float32, jnp.float32,
axes=self.bias_axes_1) axes=self.bias_axes_1)
bias = bias.astype(self.dtype) bias_1 = bias_1.astype(self.dtype)
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
x += jnp.reshape(bias, bias_shape) x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name) x = checkpoint_name(x, ffn1_ckpt_name)
activations = [] activations = []
if is_geglu(self.activations): if is_act_implemented:
z = geglu(x) z = activation_lu(x, normalize_acts)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
else: else:
x = jnp.split(x, num_activations, axis=-2) x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations): for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = functools.reduce(operator.mul, activations) z = functools.reduce(operator.mul, activations)
if not is_gated:
z = jnp.reshape(z, (*z.shape[:-2], -1)) z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate, z = nn.Dropout(rate=self.intermediate_dropout_rate,
...@@ -1207,14 +1180,14 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1207,14 +1180,14 @@ class LayerNormMLP(TransformerEngineBase):
out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel, out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
wo_lora_b_kernel, self.low_rank_adaptation_alpha) wo_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None bias_2 = None
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias', bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,), self.bias_init, (hidden_size,),
jnp.float32, jnp.float32,
axes=self.bias_axes_2) axes=self.bias_axes_2)
bias = bias.astype(self.dtype) bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,)) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name) out = checkpoint_name(out, ffn2_ckpt_name)
......
...@@ -3,15 +3,15 @@ ...@@ -3,15 +3,15 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX MLP modules""" """JAX MLP modules"""
from typing import List, Tuple from typing import List, Tuple, Sequence, Union, Callable
from functools import partial from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from .cpp_extensions import cast_fp8, transpose, cast_transpose from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import gelu as te_gelu from .cpp_extensions import gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8 from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
...@@ -23,369 +23,56 @@ from .fp8 import FP8Helper, FP8MetaPackage ...@@ -23,369 +23,56 @@ from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes from .sharding import with_sharding_constraint_by_logical_axes
def gelu(x: jnp.ndarray): activation_dict = {
""" ('gelu',): {'fwd': gelu,
Gelu "bwd": dgelu},
""" ('gelu', 'linear'): {'fwd': gated_gelu,
output = _gelu(x) 'bwd': dgated_gelu}
return output }
@partial(jax.custom_vjp)
def _gelu(x: jnp.ndarray):
geglu_output, _ = _gelu_fwd_rule(x)
return geglu_output
def _gelu_fwd_rule(x):
geglu_output = te_gelu(x)
return geglu_output, (x,)
activation_fp8_dict = {
('gelu',): {'fwd': gelu_fp8,
'bwd': dgelu_dbias_cast_transpose},
('gelu', 'linear'): {'fwd': gated_gelu_fp8,
'bwd': dgated_gelu_cast_transpose}
}
def _gelu_bwd_rule(ctx, g):
x, = ctx
assert x.dtype == g.dtype
dx = dgelu(g, x)
dx = jnp.reshape(dx, x.shape)
return (dx,)
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule)
def geglu(x: jnp.ndarray):
""" """
Gated gelu Activation Unit
""" """
if len(activation_type) > 1:
assert x.shape[-2] == 2 # Linear + GeLU assert x.shape[-2] == 2 # Linear + GeLU
output = _activation_lu(x, activation_type)
output = _geglu(x)
return output return output
@partial(jax.custom_vjp) @partial(jax.custom_vjp, nondiff_argnums=(1,))
def _geglu(x: jnp.ndarray): def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
geglu_output, _ = _geglu_fwd_rule(x) _output, _ = _activation_lu_fwd_rule(x, activation_type)
return geglu_output return _output
def _geglu_fwd_rule(x): def _activation_lu_fwd_rule(x, activation_type):
geglu_output = gated_gelu(x) fwd_output = activation_dict[activation_type]["fwd"](x)
return geglu_output, (x,) return fwd_output, (x,)
def _geglu_bwd_rule(ctx, g): def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx x, = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = dgated_gelu(g, x) dx = activation_dict[activation_type]["bwd"](g, x)
dx = jnp.reshape(dx, x.shape) dx = jnp.reshape(dx, x.shape)
return (dx,) return (dx,)
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
fp8_gemm_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
"""
Layernorm + GEMM1 + GeGLU + GEMM2
"""
assert len(kernels) == 2
assert fp8_gemm_pkg.num_of_gemm == len(kernels)
kernel_1 = kernels[0]
kernel_2 = kernels[1]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name)
return output
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str):
output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
return output
def _layernorm_geglu_fp8_mlp_fwd_rule(
x,
gamma,
beta,
kernel_1,
kernel_2,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name):
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == 2
assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,)
xt_batch_dims = tuple(range(1, x.ndim))
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm1_x_idx, 0:1]
x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
gamma,
beta,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
mu = None
assert x.shape == ln_out.shape
kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
# Note (Ming Huang): Use cast only to allow XLA handle transpose for avoiding
# unnecessary copy to break FP8 GEMM pattern matching.
casted_kernel_1, updated_kernel_1_amax = \
cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
# (batch..., hidden_in) x (hidden_in, 2, hidden_out)
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
geglu_out_amax = amax[gemm2_x_idx, 0:1]
geglu_out_scale = scale[gemm2_x_idx]
geglu_out_scale_inv = scale_inv[gemm2_x_idx]
# (batch..., hidden_in) -> (batch..., hidden)
casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax,
geglu_out_scale, geglu_out_scale_inv,
fwd_dtype)
casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
# Note (Ming Huang): Use native cast to allow XLA handle transpose for avoiding
# unnecessary copy to break FP8 GEMM pattern matching.
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)
return dot_2_output, ctx
def _layernorm_geglu_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims = ctx
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
casted_geglu_out_t = transpose(casted_geglu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
# (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgeglu_amax = amax[gemm1_grad_idx, 0:1]
dgeglu_scale = scale[gemm1_grad_idx]
dgeglu_scale_inv = scale_inv[gemm1_grad_idx]
casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose(
dgrad_2,
dot_1_output,
dgeglu_amax,
dgeglu_scale,
dgeglu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (2, hidden, batch...)
xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims)
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, \
fp8_max, amax, scale, scale_inv
_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
_layernorm_geglu_fp8_mlp_bwd_rule)
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, beta: jnp.ndarray,
kernels: List[jnp.ndarray], kernels: List[jnp.ndarray],
...@@ -398,9 +85,11 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray, ...@@ -398,9 +85,11 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1', ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray: ffn2_ckpt_name: str = 'ffn2',
activation_type: Sequence[Union[str, Callable]] = ('gelu',),
use_bias: bool = True) -> jnp.ndarray:
""" """
Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias Layernorm + GEMM1 + bias + activation + GEMM2 + bias
""" """
assert len(kernels) == 2 assert len(kernels) == 2
...@@ -424,32 +113,36 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray, ...@@ -424,32 +113,36 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'" "if layernorm_type is 'rmsnorm'"
output = _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type, amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon, layernorm_input_axes, zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name) ffn2_ckpt_name, activation_type, use_bias)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) @partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool, bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
epsilon: float, layernorm_input_axes: Tuple[str, ...], epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str): ffn1_ckpt_name: str, ffn2_ckpt_name: str,
output, _ = _layernorm_gelu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, activation_type: Sequence[Union[str, Callable]],
fp8_max, amax, scale, scale_inv, fwd_dtype, use_bias: bool):
bwd_dtype, layernorm_type, zero_centered_gamma, output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
epsilon, layernorm_input_axes, dot_1_input_axes, bias_2, fp8_max, amax, scale, scale_inv,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name) fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
return output return output
def _layernorm_gelu_fp8_mlp_fwd_rule( def _fused_layernorm_fp8_mlp_fwd_rule(
x, x,
gamma, gamma,
beta, beta,
...@@ -470,13 +163,16 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( ...@@ -470,13 +163,16 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name): ffn2_ckpt_name,
activation_type,
use_bias):
is_gated = len(activation_type) > 1
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out) # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out) # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3 assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == 1 assert kernel_1.shape[-2] == len(activation_type)
assert len(kernel_2.shape) == 2 assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
...@@ -487,6 +183,7 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( ...@@ -487,6 +183,7 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
# Squeeze act axis # Squeeze act axis
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out) # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
if not is_gated:
kernel_1 = jnp.squeeze(kernel_1, axis=-2) kernel_1 = jnp.squeeze(kernel_1, axis=-2)
amax = FP8Helper.update_amax_history(amax) amax = FP8Helper.update_amax_history(amax)
...@@ -539,22 +236,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( ...@@ -539,22 +236,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)), (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias:
bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
dot_1_output += jnp.reshape(bias_1, bias_1_shape) dot_1_output += jnp.reshape(bias_1, bias_1_shape)
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
gelu_out_amax = amax[gemm2_x_idx, 0:1] activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
gelu_out_scale = scale[gemm2_x_idx] activation_lu_out_scale = scale[gemm2_x_idx]
gelu_out_scale_inv = scale_inv[gemm2_x_idx] activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_fp8 = activation_fp8_dict[activation_type]["fwd"]
# (batch..., hidden_in) -> (batch..., hidden) # (batch..., hidden_in) -> (batch..., hidden)
casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale, casted_activation_lu_out, updated_activation_lu_amax = activation_lu_fp8(dot_1_output,
gelu_out_scale_inv, fwd_dtype) activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype)
casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes) casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
...@@ -563,23 +264,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( ...@@ -563,23 +264,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule(
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale) casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
# (batch..., hidden_in) x (hidden_out, hidden_in) # (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv, dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
activation_lu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
if use_bias:
bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
dot_2_output += jnp.reshape(bias_2, bias_2_shape) dot_2_output += jnp.reshape(bias_2, bias_2_shape)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1, ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
bias_1.shape, bias_2.shape) x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape)
return dot_2_output, ctx return dot_2_output, ctx
def _layernorm_gelu_fp8_mlp_bwd_rule( def _fused_layernorm_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument fwd_dtype, # pylint: disable=unused-argument
bwd_dtype, bwd_dtype,
layernorm_type, layernorm_type,
...@@ -590,13 +294,17 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( ...@@ -590,13 +294,17 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
dot_2_input_axes, dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument ffn2_ckpt_name, # pylint: disable=unused-argument
activation_type,
use_bias,
ctx, ctx,
grad): grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \ x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx
is_gated = len(activation_type) > 1
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1] grad_amax = amax[gemm2_grad_idx, 0:1]
...@@ -606,21 +314,29 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( ...@@ -606,21 +314,29 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
# Since the sharding of outputs should be the same as dot_1's input # Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
dbias_cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
else:
casted_grad, casted_grad_t, updated_grad_amax = \ casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, cast_transpose(grad, grad_amax, grad_scale,
static_axis_boundary=-1, transpose_axis_boundary=-1) grad_scale_inv, bwd_dtype,
casted_gelu_out_t = transpose(casted_gelu_out,
static_axis_boundary=-1, static_axis_boundary=-1,
transpose_axis_boundary=-1) transpose_axis_boundary=-1)
dbias_2 = jnp.empty(bias_2_shape, grad.dtype)
dbias_2 = jnp.sum(grad, axis=(i for i in range(grad.ndim - 1))) casted_activation_lu_out_t = transpose(casted_activation_lu_out,
dbias_2 = jnp.reshape(dbias_2, bias_2_shape) static_axis_boundary=-1,
transpose_axis_boundary=-1)
# (hidden, batch...,) x (hidden, batch...) # (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx] gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv, wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims), grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., hidden_out) x (hidden_in, hidden_out) # (batch..., hidden_out) x (hidden_in, hidden_out)
...@@ -633,36 +349,85 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( ...@@ -633,36 +349,85 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgelu_amax = amax[gemm1_grad_idx, 0:1] dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
dgelu_scale = scale[gemm1_grad_idx] dactivation_lu_scale = scale[gemm1_grad_idx]
dgelu_scale_inv = scale_inv[gemm1_grad_idx] dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]
casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose( dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"]
dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)
if is_gated:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dbias_cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
else:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
dactivation_lu_cast_transpose(
dgrad_2, dgrad_2,
dot_1_output, dot_1_output,
dgelu_amax, dactivation_lu_amax,
dgelu_scale, dactivation_lu_scale,
dgelu_scale_inv, dactivation_lu_scale_inv,
bwd_dtype, bwd_dtype,
static_axis_boundary=-1, static_axis_boundary=-1,
transpose_axis_boundary=-1) transpose_axis_boundary=-1)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape) dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx] gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype, xt_batch_dims_2 = xt_batch_dims if not is_gated \
(xt_batch_dims, xt_batch_dims), else tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype,
(xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1 # Expand act axis to match the shape with the given kernel_1
if not is_gated:
wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2) wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
# (batch..., hidden_out) x (hidden_in, hidden_out) # (batch..., hidden_out) x (hidden_in, hidden_out)
if is_gated:
x_contracting_dims = ((min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims), (1,2))
else:
x_contracting_dims = (x_contracting_dims, (1,))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv, dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1,
grad.dtype, (x_contracting_dims, (1,)), dactivation_lu_scale_inv, kernel_1_scale_inv,
grad.dtype, x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
...@@ -683,15 +448,15 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( ...@@ -683,15 +448,15 @@ def _layernorm_gelu_fp8_mlp_bwd_rule(
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0]) amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0]) amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0]) amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0]) amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \ return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
fp8_max, amax, scale, scale_inv fp8_max, amax, scale, scale_inv
_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule) _fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
_fused_layernorm_fp8_mlp_bwd_rule)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment