Unverified Commit ce163f9e authored by Ming-Xu Huang, committed by GitHub

[JAX] Support SP + RoPE + GeLU (#602)



* Adding support for sequence parallelism
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding RoPE
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong batch_logical_axes
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Renaming FSDP outer env var
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting RoPE to Praxis layers.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting GeLU + [FP8 Cast].
Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Allowing arbitrary dimensions of NVShape for the workspace allocation
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modify with review feedback.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix bugs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix for lint.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Follow review feedback to modify code.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Port SP to Praxis
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix an issue when enabling both GQA and RoPE.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update docs
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
parent 29b0c9ca
......@@ -6,10 +6,23 @@
Jax
=======
.. autoapiclass:: transformer_engine.jax.MajorShardingType
.. autoapiclass:: transformer_engine.jax.ShardingType
Pre-defined Variables of Logical Axes
-------------------------------------
These variables are available in `transformer_engine.jax.sharding`; a usage sketch follows the list.
* BATCH_AXES: The logical axis of the batch dimension. It is usually sharded along DP + FSDP on the Mesh.
* SEQLEN_AXES: The logical axis of the sequence-length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of the sequence-length dimension. It is usually sharded along TP on the Mesh.
* HEAD_AXES: The logical axis of the head dimension of MHA. It is usually sharded along TP on the Mesh.
* HIDDEN_AXES: The logical axis of the hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of the hidden dimension. It is usually sharded along TP on the Mesh.
* JOINED_AXES: The logical axis of any remaining, unnamed dimension. It is usually not sharded.
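A minimal usage sketch; the mesh axis names ``data`` and ``model`` below are illustrative assumptions, not part of the API.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import numpy as np
    import transformer_engine.jax as te
    from transformer_engine.jax import sharding

    devices = np.asarray(jax.devices()).reshape(-1, 1)
    with jax.sharding.Mesh(devices, ('data', 'model')), \
         te.fp8_autocast(enabled=False,
                         mesh_resource=te.MeshResource(dp_resource='data',
                                                       tp_resource='model')):

        @jax.jit
        def constrain(x):
            # Constrain x laid out as (batch, seqlen, hidden): the batch dimension
            # follows the DP/FSDP mesh axes; sequence and hidden stay unsharded.
            return sharding.with_sharding_constraint_by_logical_axes(
                x, (sharding.BATCH_AXES, sharding.SEQLEN_AXES, sharding.HIDDEN_AXES))

        x = constrain(jnp.zeros((8, 128, 1024), jnp.bfloat16))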
Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
.. autoapiclass:: transformer_engine.jax.MeshResource()
.. autoapifunction:: transformer_engine.jax.fp8_autocast
......
......@@ -35,6 +35,7 @@ INPUT_KEY = 'input_rng'
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
enable_seq_paral: bool
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
......@@ -50,11 +51,17 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger an all-gather to collect the complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(x,
jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))
x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
......@@ -266,7 +273,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)):
encoder = Net(num_embed)
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
......@@ -379,6 +386,10 @@ def encoder_parser(args):
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument("--enable-sp",
action="store_true",
default=False,
help="Enable sequence parallelism.")
return parser.parse_args(args)
......@@ -405,6 +416,20 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -14,6 +14,8 @@ from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
......@@ -21,6 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
GEMM_CASES = [
(256, 256, 512),
......@@ -285,6 +288,126 @@ class TestFP8Dot:
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is the input tensor, a 2D matrix
# ln_s is the RMSNorm scale; y, z are kernels; w, v are biases
# out = gelu(rmsnorm(x, ln_s) @ y + w) @ z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))
def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = jax.nn.gelu(linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
return output
def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv
pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
pri_fp8_metas_scale, pri_fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
......@@ -294,6 +417,96 @@ def random_inputs_fixture(shape):
return out
class TestGeLu:
def ref_func(self, inputs):
func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x))))
return func(inputs)
def prim_func(self, inputs):
@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out
def primitive_fwd(x):
out = gelu(x)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
out = dgelu(g, x)
return (out,)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x: jnp.mean(primitive(x)))
return func(inputs)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
class TestGeLuFP8(TestGeLu):
def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)
@jax.custom_vjp
def primitive(x, y, z, w):
out, _ = primitive_fwd(x, y, z, w)
return out
def primitive_fwd(x, y, z, w):
out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
out = dequantize(out, x.dtype, scale_inv)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
dgelu = dequantize(dgelu, x.dtype, scale_inv)
dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
return dgelu, dgelu_trans, dbias, amax_out
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))
return func(inputs, no_use, no_use, no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
x = random_inputs
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, (2, 0, 1)),
dtype=FP8Helper.BWD_DTYPE)
class TestGatedGeLu:
def ref_func(self, inputs):
......
......@@ -88,6 +88,7 @@ _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -137,7 +138,25 @@ ATTRS = [{
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......
......@@ -818,12 +818,14 @@ class TransformerLayerAttr:
LYR_TYPE = 'layer_type'
ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -831,6 +833,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -838,6 +841,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -845,6 +849,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -852,6 +857,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -859,6 +865,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -866,6 +873,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -873,6 +881,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -880,6 +889,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -887,6 +897,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -894,6 +905,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -901,6 +913,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -908,6 +921,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -915,6 +929,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -922,6 +937,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -929,6 +945,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -936,6 +953,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -943,6 +961,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -950,6 +969,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -957,6 +977,23 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False
}]
......@@ -984,6 +1021,7 @@ class TestTransformer(TestLayer):
use_bias = attrs[TransformerLayerAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
num_attention_heads=num_attention_heads)
......@@ -1019,6 +1057,7 @@ class TestTransformer(TestLayer):
bias_init=bias_init,
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
......@@ -1040,6 +1079,7 @@ class TestTransformer(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init(
"bias", bias_init),
layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=drop_path,
......
......@@ -340,6 +340,29 @@ class MlpBlock(nn.Module):
return output
def apply_rotary_pos_emb(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale)**fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1)
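# Usage sketch for apply_rotary_pos_emb (illustrative shapes, not part of the API):
# given activations laid out as (batch, seqlen, hidden), callers build integer
# positions broadcastable to (batch, seqlen), and RoPE rotates pairs formed from the
# two halves of the last dimension, which is therefore assumed to be even, e.g.
#     q = jnp.zeros((2, 16, 64), jnp.bfloat16)
#     pos = jnp.expand_dims(jnp.arange(16), axis=0)
#     q = apply_rotary_pos_emb(q, pos)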
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -368,6 +391,7 @@ class MultiHeadAttention(nn.Module):
float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False
scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
fuse_qkv: bool = True
def __post_init__(self):
......@@ -482,6 +506,15 @@ class MultiHeadAttention(nn.Module):
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
query = apply_rotary_pos_emb(query, position)
key = apply_rotary_pos_emb(key, position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
......@@ -802,6 +835,7 @@ class EncoderLayer(nn.Module):
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -854,6 +888,7 @@ class EncoderLayer(nn.Module):
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
name='attention')(x,
x,
encoder_mask,
......@@ -922,6 +957,7 @@ class DecoderLayer(nn.Module):
layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -981,6 +1017,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params,
name='self_attention')(x,
x,
......@@ -1014,6 +1051,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params,
name='encoder_decoder_attention')(y,
encoded,
......
......@@ -1349,12 +1349,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2,
......@@ -1396,6 +1390,12 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
return;
}
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
......
......@@ -25,6 +25,10 @@ pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gelu"] = EncapsulateFunction(Gelu);
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
......@@ -55,6 +59,7 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
......@@ -62,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
......
......@@ -52,6 +52,18 @@ struct CustomCallCommonDescriptor {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype);
struct CustomCallCommonWkDescriptor {
Shape shape;
Shape wkshape;
DType in_dtype;
DType out_dtype;
DType wk_dtype;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype);
struct CustomCallNormDescriptor {
size_t batch_size;
size_t hidden_size;
......@@ -73,10 +85,10 @@ struct CustomCallNormDescriptor {
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype,
DType wkspace_dtype, DType barrier_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype,
bool zero_centered_gamma, float eps, int sm_margin);
DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
......@@ -110,11 +122,10 @@ struct CustomCallFusedAttnDescriptor {
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
DType wkspace_dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -127,6 +138,18 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -136,20 +159,20 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps
);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
bool zero_centered_gamma, float eps
);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -182,39 +205,33 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
......
......@@ -23,8 +23,10 @@ from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -502,6 +504,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
If set False, return None as the second tensor in outputs.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of layernorm,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the dot,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
Optimization parameters
-----------------------
......@@ -534,6 +544,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
sharding_type = None
......@@ -571,6 +583,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1    # Only support axis == -1 at this moment
features = inputs.shape[-1]
......@@ -626,8 +640,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = type_safe_dot_general(y,
kernel,
fp8_meta_pkg=fp8_meta_package,
......@@ -730,6 +747,18 @@ class LayerNormMLP(TransformerEngineBase):
Dimensions that will share the same dropout mask for hidden
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of layernorm,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the 1st dot,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the 2nd dot,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
Optimization parameters
-----------------------
......@@ -765,6 +794,9 @@ class LayerNormMLP(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
major_sharding_type = None
def __post_init__(self):
......@@ -812,13 +844,28 @@ class LayerNormMLP(TransformerEngineBase):
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_mlp = fuse_layernorm \
def is_gelu(acts):
gelu_act_pool = [('gelu',)]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in gelu_act_pool
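# Fused-kernel selection: the fused LN + GeGLU + FP8 MLP path is only taken without
# bias and with a GeGLU activation set, while the fused LN + GeLU + FP8 MLP path
# requires a bias and exactly ('gelu',); both also require fused layernorm and an
# effectively disabled intermediate dropout, as checked just below.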
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
# LayerNorm
if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1]
......@@ -883,7 +930,10 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis)))
if use_fused_ln_mlp:
ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp:
assert self.axis == -1    # Only support axis == -1 at this moment
out = layernorm_geglu_fp8_mlp(y,
......@@ -892,8 +942,41 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
else: # not use_fused_ln_mlp
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
assert self.axis == -1    # Only support axis == -1 at this moment
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
else:    # neither fused LN+GeGLU nor fused LN+GeLU MLP path applies
# DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \
......@@ -906,8 +989,11 @@ class LayerNormMLP(TransformerEngineBase):
gemm1_fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general(y,
kernel_1,
fp8_meta_pkg=gemm1_fp8_meta_package,
......@@ -924,11 +1010,14 @@ class LayerNormMLP(TransformerEngineBase):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
x = checkpoint_name(x, 'ffn1')
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_geglu(self.activations):
z = geglu(x)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
......@@ -942,6 +1031,8 @@ class LayerNormMLP(TransformerEngineBase):
rng_collection=self.intermediate_dropout_rng_name)(
z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
# DenseGeneral 2
gemm2_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(1)
......@@ -960,6 +1051,6 @@ class LayerNormMLP(TransformerEngineBase):
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, 'ffn2')
out = checkpoint_name(out, ffn2_ckpt_name)
return out, ln_output    # output, layer_norm_output
......@@ -4,6 +4,8 @@
"""JAX layernorm modules"""
from functools import partial
from typing import Tuple
import jax
import jax.numpy as jnp
......@@ -12,6 +14,7 @@ from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def canonicalize_layernorm_type(x):
......@@ -96,14 +99,20 @@ def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
def layernorm_fp8_dot(x: jnp.ndarray,
def layernorm_fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6) -> jnp.ndarray:
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,    # Logical axes of the sharding constraint on the layernorm input.
dot_input_axes: Tuple[str, ...] = None    # Logical axes of the sharding constraint on the dot input.
) -> jnp.ndarray:
"""
Layernorm + FP8 GEMM
"""
......@@ -114,18 +123,21 @@ def layernorm_fp8_dot(x: jnp.ndarray,
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon)
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_input_axes)
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12))
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float):
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon)
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_input_axes)
return output
......@@ -142,7 +154,9 @@ def _layernorm_fp8_dot_fwd_rule(
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon):
epsilon,
layernorm_input_axes,
dot_input_axes):
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
......@@ -156,6 +170,8 @@ def _layernorm_fp8_dot_fwd_rule(
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
......@@ -191,6 +207,8 @@ def _layernorm_fp8_dot_fwd_rule(
casted_kernel, updated_kernel_amax = \
cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, k_contracting_dims),
......@@ -209,6 +227,8 @@ def _layernorm_fp8_dot_bwd_rule(
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad):
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
......@@ -243,6 +263,7 @@ def _layernorm_fp8_dot_bwd_rule(
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad,
x,
......
......@@ -77,6 +77,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
......@@ -109,6 +110,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits)
......@@ -156,11 +158,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
......@@ -221,11 +226,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init)
......
......@@ -4,7 +4,7 @@
"""
Sharding Meta for xmap with CustomCall
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
......@@ -16,6 +16,19 @@ from jax.sharding import PartitionSpec
_PXLA_THREAD_RESOURCES = pxla.thread_resources
# Axis Names
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
SEQLEN_TP_AXES = 'nvte_seqlen_tp'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
......@@ -24,6 +37,81 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource
def get_sharding_map_logic_axis_to_mesh_axis():
"""
Generate a dict to map logical axes to mesh axes.
"""
gsr = global_mesh_resource()
IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))
batch_resources = [gsr.fsdp_resource, gsr.dp_resource] if IS_FSDP_OUTER \
else [gsr.dp_resource, gsr.fsdp_resource]
batch_dim_rule = []
for resource in batch_resources:
if resource is not None and resource not in batch_dim_rule:
batch_dim_rule.append(resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
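# Example (mesh axis names are illustrative): with dp_resource='data' and
# fsdp_resource='fsdp', batch_dim_rule becomes ('data', 'fsdp') by default, or
# ('fsdp', 'data') when NVTE_OUTER_BATCH_FSDP_DIM is set to a non-zero value;
# None and duplicate resources are dropped, and an empty result maps to None.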
te_logical_axis_to_mesh_axis = {
BATCH_AXES: batch_dim_rule,
SEQLEN_AXES: None,
SEQLEN_TP_AXES: gsr.tp_resource,
HEAD_AXES: gsr.tp_resource,
HIDDEN_AXES: None,
HIDDEN_TP_AXES: gsr.tp_resource,
JOINED_AXES: None,
W_NO_SHARD_AXES: None,
W_FSDP_AXES: gsr.fsdp_resource,
W_TP_AXES: gsr.tp_resource,
W_JOINED_AXES: None,
}
return te_logical_axis_to_mesh_axis
def generate_pspec(logical_axis_names):
"""
Convert logical axes to PartitionSpec
"""
rules = get_sharding_map_logic_axis_to_mesh_axis()
mesh_axis_names = [rules[name] for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return pspec
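# For illustration (mesh axis names are assumptions; actual values depend on the
# configured MeshResource): with dp_resource='data', fsdp_resource='fsdp' and
# tp_resource='model', generate_pspec((BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES))
# yields PartitionSpec(('data', 'fsdp'), None, 'model').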
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
if pspec is None:
return x
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
"""
if logical_axis_names is None:
return x
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
def get_all_mesh_axes():
"""
Get all name of mesh axes
......@@ -42,17 +130,6 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
"""
A wrapper function to invoke lax.p* operations, like psum.
......