Unverified Commit ce163f9e authored by Ming-Xu Huang, committed by GitHub

[JAX] Support SP + RoPE + GeLU (#602)



* Adding support for sequence parallelism
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding RoPE
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong batch_logical_axes
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Renaming FSDP outer env var
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting RoPE to Praxis layers.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting GeLU + [FP8 Cast].
Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Allowing arbitrary dimensions of NVShape for the workspace allocation
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modify with review feedback.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix bugs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed for lint
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Follow review feedback to modify code.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Port SP to Praxis
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix an issue when enabling both GQA and RoPE.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update docs
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
parent 29b0c9ca
...@@ -6,10 +6,23 @@ ...@@ -6,10 +6,23 @@
Jax Jax
======= =======
.. autoapiclass:: transformer_engine.jax.MajorShardingType Pre-defined Variables of Logical Axes
.. autoapiclass:: transformer_engine.jax.ShardingType ------------------------------------
Variables are available in `transformer_engine.jax.sharding`.
* BATCH_AXES: The logical axis of the batch dimension. It is usually sharded along DP + FSDP on the mesh.
* SEQLEN_AXES: The logical axis of the sequence-length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of the sequence-length dimension. It is usually sharded along TP on the mesh.
* HEAD_AXES: The logical axis of the head dimension of MHA. It is usually sharded along TP on the mesh.
* HIDDEN_AXES: The logical axis of the hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of the hidden dimension. It is usually sharded along TP on the mesh.
* JOINED_AXES: The logical axis of non-defined dimensions. It is usually not sharded.
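As a quick orientation (not part of this commit), here is a minimal sketch of how these logical axes might be bound to a two-dimensional (DP, TP) device mesh through Flax axis rules; the mesh shape, mesh-axis names, and rule pairs below are assumptions for illustration only.

```python
# Illustrative sketch only: bind TE's pre-defined logical axes to a 2D mesh.
import numpy as np
import jax
from jax.sharding import Mesh
from flax.linen import partitioning as nn_partitioning
from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES,
                                             HEAD_AXES, HIDDEN_AXES, HIDDEN_TP_AXES,
                                             JOINED_AXES)

devices = np.asarray(jax.devices()).reshape(-1, 2)      # e.g. 4-way DP x 2-way TP
mesh = Mesh(devices, ('data', 'model'))

rules = ((BATCH_AXES, 'data'),        # batch sharded over the DP axis
         (SEQLEN_AXES, None),         # sequence replicated by default
         (SEQLEN_TP_AXES, 'model'),   # sequence sharded over TP when SP is enabled
         (HEAD_AXES, 'model'),        # attention heads sharded over TP
         (HIDDEN_AXES, None),
         (HIDDEN_TP_AXES, 'model'),
         (JOINED_AXES, None))

with mesh, nn_partitioning.axis_rules(rules):
    pass  # construct and run TE/Flax modules here
```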
Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None) .. autoapiclass:: transformer_engine.jax.MeshResource()
.. autoapifunction:: transformer_engine.jax.fp8_autocast .. autoapifunction:: transformer_engine.jax.fp8_autocast
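For reference, a minimal sketch of the renamed MeshResource in use (the keyword names and mesh-axis strings here are assumptions; compare with the encoder example further down, which passes the resources positionally):

```python
# Illustrative sketch: MeshResource names the mesh axes used for DP/TP roles and
# is consumed by fp8_autocast, replacing the removed ShardingResource on the left.
import transformer_engine.jax as te

with te.fp8_autocast(enabled=True,
                     mesh_resource=te.MeshResource(dp_resource='data',
                                                   tp_resource='model')):
    pass  # build and run TE modules under this mesh resource
```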
......
...@@ -35,6 +35,7 @@ INPUT_KEY = 'input_rng' ...@@ -35,6 +35,7 @@ INPUT_KEY = 'input_rng'
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
enable_seq_paral: bool
@nn.compact @nn.compact
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
...@@ -50,11 +51,17 @@ class Net(nn.Module): ...@@ -50,11 +51,17 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER, layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding', self_attn_mask_type='padding',
enable_relative_embedding=False, enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger all-gather to collect the complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(x,
jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))
x = te_flax.DenseGeneral(features=256, x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,), bias_axes=(NAMED_TP_AXIS,),
...@@ -266,7 +273,7 @@ def train_and_evaluate(args): ...@@ -266,7 +273,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8, with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)): None)):
encoder = Net(num_embed) encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
...@@ -379,6 +386,10 @@ def encoder_parser(args): ...@@ -379,6 +386,10 @@ def encoder_parser(args):
action="store_true", action="store_true",
default=False, default=False,
help="Use FP8 for inference and training without recalibration") help="Use FP8 for inference and training without recalibration")
parser.add_argument("--enable-sp",
action="store_true",
default=False,
help="Enable sequence parallelism.")
return parser.parse_args(args) return parser.parse_args(args)
...@@ -405,6 +416,20 @@ class TestEncoder(unittest.TestCase): ...@@ -405,6 +416,20 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79 assert actual[0] < 0.45 and actual[1] > 0.79
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -14,6 +14,8 @@ from jax import jit, value_and_grad ...@@ -14,6 +14,8 @@ from jax import jit, value_and_grad
from flax import linen as nn from flax import linen as nn
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8 from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
...@@ -21,6 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper ...@@ -21,6 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
GEMM_CASES = [ GEMM_CASES = [
(256, 256, 512), (256, 256, 512),
...@@ -285,6 +288,126 @@ class TestFP8Dot: ...@@ -285,6 +288,126 @@ class TestFP8Dot:
jnp.asarray(ref_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x: 2D input tensor
# y, z: 2D weight matrices; w, v: biases
# out = gelu(rmsnorm(x, ln_s) @ y + w) @ z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))
def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = jax.nn.gelu(linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
return output
def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv
pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
pri_fp8_metas_scale, pri_fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape): def random_inputs_fixture(shape):
...@@ -294,6 +417,96 @@ def random_inputs_fixture(shape): ...@@ -294,6 +417,96 @@ def random_inputs_fixture(shape):
return out return out
class TestGeLu:
def ref_func(self, inputs):
func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x))))
return func(inputs)
def prim_func(self, inputs):
@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out
def primitive_fwd(x):
out = gelu(x)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
out = dgelu(g, x)
return (out,)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x: jnp.mean(primitive(x)))
return func(inputs)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
class TestGeLuFP8(TestGeLu):
def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)
@jax.custom_vjp
def primitive(x, y, z, w):
out = primitive_fwd(x)
return out
def primitive_fwd(x, y, z, w):
out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
out = dequantize(out, x.dtype, scale_inv)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
dgelu = dequantize(dgelu, x.dtype, scale_inv)
dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
return dgelu, dgelu_trans, dbias, amax_out
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))
return func(inputs, no_use, no_use, no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
x = random_inputs
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, (2, 0, 1)),
dtype=FP8Helper.BWD_DTYPE)
class TestGatedGeLu: class TestGatedGeLu:
def ref_func(self, inputs): def ref_func(self, inputs):
......
...@@ -88,6 +88,7 @@ _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence' ...@@ -88,6 +88,7 @@ _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits" _KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads' _KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups' _KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = { BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
...@@ -137,7 +138,25 @@ ATTRS = [{ ...@@ -137,7 +138,25 @@ ATTRS = [{
_KEY_OF_FUSE_MLP_WI: True _KEY_OF_FUSE_MLP_WI: True
}, { }, {
_KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4 _KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}] }]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......
...@@ -818,12 +818,14 @@ class TransformerLayerAttr: ...@@ -818,12 +818,14 @@ class TransformerLayerAttr:
LYR_TYPE = 'layer_type' LYR_TYPE = 'layer_type'
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ATTRS = [{ ATTRS = [{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -831,6 +833,7 @@ class TransformerLayerAttr: ...@@ -831,6 +833,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -838,6 +841,7 @@ class TransformerLayerAttr: ...@@ -838,6 +841,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -845,6 +849,7 @@ class TransformerLayerAttr: ...@@ -845,6 +849,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -852,6 +857,7 @@ class TransformerLayerAttr: ...@@ -852,6 +857,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -859,6 +865,7 @@ class TransformerLayerAttr: ...@@ -859,6 +865,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -866,6 +873,7 @@ class TransformerLayerAttr: ...@@ -866,6 +873,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -873,6 +881,7 @@ class TransformerLayerAttr: ...@@ -873,6 +881,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -880,6 +889,7 @@ class TransformerLayerAttr: ...@@ -880,6 +889,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -887,6 +897,7 @@ class TransformerLayerAttr: ...@@ -887,6 +897,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -894,6 +905,7 @@ class TransformerLayerAttr: ...@@ -894,6 +905,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -901,6 +913,7 @@ class TransformerLayerAttr: ...@@ -901,6 +913,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -908,6 +921,7 @@ class TransformerLayerAttr: ...@@ -908,6 +921,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -915,6 +929,7 @@ class TransformerLayerAttr: ...@@ -915,6 +929,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -922,6 +937,7 @@ class TransformerLayerAttr: ...@@ -922,6 +937,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -929,6 +945,7 @@ class TransformerLayerAttr: ...@@ -929,6 +945,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -936,6 +953,7 @@ class TransformerLayerAttr: ...@@ -936,6 +953,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -943,6 +961,7 @@ class TransformerLayerAttr: ...@@ -943,6 +961,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -950,6 +969,7 @@ class TransformerLayerAttr: ...@@ -950,6 +969,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -957,6 +977,23 @@ class TransformerLayerAttr: ...@@ -957,6 +977,23 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}] }]
...@@ -984,6 +1021,7 @@ class TestTransformer(TestLayer): ...@@ -984,6 +1021,7 @@ class TestTransformer(TestLayer):
use_bias = attrs[TransformerLayerAttr.USE_BIAS] use_bias = attrs[TransformerLayerAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0) bias_init = WeightInit.Constant(0.0)
layer_type = attrs[TransformerLayerAttr.LYR_TYPE] layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases, relative_embedding = pax_fiddle.Config(RelativePositionBiases,
num_attention_heads=num_attention_heads) num_attention_heads=num_attention_heads)
...@@ -1019,6 +1057,7 @@ class TestTransformer(TestLayer): ...@@ -1019,6 +1057,7 @@ class TestTransformer(TestLayer):
bias_init=bias_init, bias_init=bias_init,
layer_type=layer_type, layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding, enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
relative_embedding=relative_embedding, relative_embedding=relative_embedding,
drop_path=drop_path, drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence)
...@@ -1040,6 +1079,7 @@ class TestTransformer(TestLayer): ...@@ -1040,6 +1079,7 @@ class TestTransformer(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init( bias_init=TransformerEngineBaseLayer.generate_params_init(
"bias", bias_init), "bias", bias_init),
layer_type=layer_type, layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
enable_relative_embedding=enable_relative_embedding, enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module, relative_embedding=relative_embedding_flax_module,
drop_path=drop_path, drop_path=drop_path,
......
...@@ -340,6 +340,29 @@ class MlpBlock(nn.Module): ...@@ -340,6 +340,29 @@ class MlpBlock(nn.Module):
return output return output
def apply_rotary_pos_emb(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale)**fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1)
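A minimal usage sketch of the helper above (shapes and dtypes are assumptions); it mirrors how MultiHeadAttention below applies it to the query/key projections before the per-head reshape:

```python
# Illustrative usage sketch for apply_rotary_pos_emb defined above.
import jax.numpy as jnp

batch, seqlen, hidden = 2, 8, 64                          # hidden = num_heads * head_dim
query = jnp.ones((batch, seqlen, hidden), jnp.bfloat16)
position = jnp.expand_dims(jnp.arange(seqlen), axis=0)    # [1, seqlen], broadcast over batch
rotated = apply_rotary_pos_emb(query, position)           # same shape, pairs of dims rotated
```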
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
...@@ -368,6 +391,7 @@ class MultiHeadAttention(nn.Module): ...@@ -368,6 +391,7 @@ class MultiHeadAttention(nn.Module):
float32_logits: bool = False # computes logits in float32 for stability. float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
fuse_qkv: bool = True fuse_qkv: bool = True
def __post_init__(self): def __post_init__(self):
...@@ -482,6 +506,15 @@ class MultiHeadAttention(nn.Module): ...@@ -482,6 +506,15 @@ class MultiHeadAttention(nn.Module):
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
query = apply_rotary_pos_emb(query, position)
key = apply_rotary_pos_emb(key, position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
...@@ -802,6 +835,7 @@ class EncoderLayer(nn.Module): ...@@ -802,6 +835,7 @@ class EncoderLayer(nn.Module):
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
output_layernorm: bool = False output_layernorm: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
...@@ -854,6 +888,7 @@ class EncoderLayer(nn.Module): ...@@ -854,6 +888,7 @@ class EncoderLayer(nn.Module):
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
name='attention')(x, name='attention')(x,
x, x,
encoder_mask, encoder_mask,
...@@ -922,6 +957,7 @@ class DecoderLayer(nn.Module): ...@@ -922,6 +957,7 @@ class DecoderLayer(nn.Module):
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
...@@ -981,6 +1017,7 @@ class DecoderLayer(nn.Module): ...@@ -981,6 +1017,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
name='self_attention')(x, name='self_attention')(x,
x, x,
...@@ -1014,6 +1051,7 @@ class DecoderLayer(nn.Module): ...@@ -1014,6 +1051,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
name='encoder_decoder_attention')(y, name='encoder_decoder_attention')(y,
encoded, encoded,
......
...@@ -1349,12 +1349,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input, ...@@ -1349,12 +1349,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
Tensor *dbias, Tensor *dbias,
Tensor *workspace, Tensor *workspace,
cudaStream_t stream) { cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, NVTE_CHECK(transposed_output->data.shape.size() == 2,
...@@ -1396,6 +1390,12 @@ void cast_transpose_dbias_dgelu(const Tensor &input, ...@@ -1396,6 +1390,12 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
return; return;
} }
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) * const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
......
...@@ -25,6 +25,10 @@ pybind11::dict Registrations() { ...@@ -25,6 +25,10 @@ pybind11::dict Registrations() {
pybind11::dict dict; pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose); dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose); dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gelu"] = EncapsulateFunction(Gelu);
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
...@@ -55,6 +59,7 @@ pybind11::dict Registrations() { ...@@ -55,6 +59,7 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) { PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations); m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
...@@ -62,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -62,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes); m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
......
...@@ -52,6 +52,18 @@ struct CustomCallCommonDescriptor { ...@@ -52,6 +52,18 @@ struct CustomCallCommonDescriptor {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype, pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype); DType out_dtype);
struct CustomCallCommonWkDescriptor {
Shape shape;
Shape wkshape;
DType in_dtype;
DType out_dtype;
DType wk_dtype;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype);
struct CustomCallNormDescriptor { struct CustomCallNormDescriptor {
size_t batch_size; size_t batch_size;
size_t hidden_size; size_t hidden_size;
...@@ -73,10 +85,10 @@ struct CustomCallNormDescriptor { ...@@ -73,10 +85,10 @@ struct CustomCallNormDescriptor {
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size, size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype, DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType wkspace_dtype, DType barrier_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype, DType dbeta_part_dtype, bool zero_centered_gamma,
bool zero_centered_gamma, float eps, int sm_margin); float eps, int sm_margin);
struct SoftmaxDescriptor { struct SoftmaxDescriptor {
size_t batch_size; size_t batch_size;
...@@ -110,11 +122,10 @@ struct CustomCallFusedAttnDescriptor { ...@@ -110,11 +122,10 @@ struct CustomCallFusedAttnDescriptor {
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType wkspace_dtype, bool is_training);
DType dtype, DType wkspace_dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -127,6 +138,18 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o ...@@ -127,6 +138,18 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -136,20 +159,20 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t ...@@ -136,20 +159,20 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes( pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps bool is_layer_norm, bool zero_centered_gamma,
); float eps);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes( pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, DType in_dtype, DType w_dtype,
bool zero_centered_gamma, float eps bool is_layer_norm, bool zero_centered_gamma,
); float eps);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -182,39 +205,33 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -182,39 +205,33 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
std::size_t opaque_len); std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training bool is_training);
);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training bool is_training);
);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
......
...@@ -23,8 +23,10 @@ from ..fp8 import FP8Helper, FP8MetaPackage ...@@ -23,8 +23,10 @@ from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..softmax import is_softmax_kernel_available from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -502,6 +504,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -502,6 +504,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
If set False, return None as the second tensor in outputs. If set False, return None as the second tensor in outputs.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of layernorm, e.g.
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the dot, e.g.
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -534,6 +544,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -534,6 +544,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
sharding_type = None sharding_type = None
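A minimal construction sketch for the new logical-axis attributes (module path, feature size, and axis constants are assumptions, not part of this diff):

```python
# Illustrative sketch: request sharding constraints on the layernorm and dot
# inputs of LayerNormDenseGeneral via the new logical-axis attributes.
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES

ln_dense = te_flax.LayerNormDenseGeneral(
    features=1024,
    layernorm_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
    dot_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
)
```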
...@@ -571,6 +583,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -571,6 +583,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
) and not self.return_layernorm_output and self.enable_layernorm ) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm: if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -626,8 +640,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -626,8 +640,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_package, fp8_meta_package,
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon) epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes)
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = type_safe_dot_general(y, z = type_safe_dot_general(y,
kernel, kernel,
fp8_meta_pkg=fp8_meta_package, fp8_meta_pkg=fp8_meta_package,
...@@ -730,6 +747,18 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -730,6 +747,18 @@ class LayerNormMLP(TransformerEngineBase):
Dimensions that will share the same dropout mask for hidden Dimensions that will share the same dropout mask for hidden
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of layernorm, e.g.
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the 1st dot, e.g.
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the 2nd dot, e.g.
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -765,6 +794,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -765,6 +794,9 @@ class LayerNormMLP(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
major_sharding_type = None major_sharding_type = None
def __post_init__(self): def __post_init__(self):
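Similarly, a sketch of a sequence-parallel-friendly LayerNormMLP configuration (all values are assumptions): the layernorm input stays sharded along the sequence on the TP axis, while the dot inputs use the gathered sequence.

```python
# Illustrative sketch: logical-axis constraints for LayerNormMLP under SP.
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES,
                                             HIDDEN_AXES)

mlp = te_flax.LayerNormMLP(
    intermediate_dim=4096,
    activations=('gelu',),
    layernorm_input_axes=(BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES),
    dot_1_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
    dot_2_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
)
```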
...@@ -812,13 +844,28 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -812,13 +844,28 @@ class LayerNormMLP(TransformerEngineBase):
normalize_acts.append(act.lower()) normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_mlp = fuse_layernorm \ def is_gelu(acts):
gelu_act_pool = [('gelu',)]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in gelu_act_pool
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \ and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) and (self.intermediate_dropout_rate < 1e-3)
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -883,7 +930,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -883,7 +930,10 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
if use_fused_ln_mlp: ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_geglu_fp8_mlp(y, out = layernorm_geglu_fp8_mlp(y,
...@@ -892,8 +942,41 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -892,8 +942,41 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_package, fp8_meta_package,
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon) epsilon=self.epsilon,
else: # not use_fused_ln_mlp layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
assert self.axis == -1 # Only support axis == -1 at this moment
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1 # DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \ gemm1_fp8_meta_package = None if fp8_meta_package is None \
...@@ -906,8 +989,11 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -906,8 +989,11 @@ class LayerNormMLP(TransformerEngineBase):
gemm1_fp8_meta_package, gemm1_fp8_meta_package,
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon) epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes)
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general(y, x = type_safe_dot_general(y,
kernel_1, kernel_1,
fp8_meta_pkg=gemm1_fp8_meta_package, fp8_meta_pkg=gemm1_fp8_meta_package,
...@@ -924,11 +1010,14 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -924,11 +1010,14 @@ class LayerNormMLP(TransformerEngineBase):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape) x += jnp.reshape(bias, bias_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_geglu(self.activations):
z = geglu(x)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
@@ -942,6 +1031,8 @@ class LayerNormMLP(TransformerEngineBase):
rng_collection=self.intermediate_dropout_rng_name)(
z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
# DenseGeneral 2
gemm2_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(1)
@@ -960,6 +1051,6 @@ class LayerNormMLP(TransformerEngineBase):
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
return out, ln_output # Output, layer_norm_output
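
Review note: the sketch below shows how the new sharding hints of LayerNormMLP might be wired up from user code. Only the attribute names (layernorm_input_axes, dot_1_input_axes, dot_2_input_axes) and the pre-defined logical axes come from this change; the import path, hyper-parameters, and axis ordering are assumptions for illustration.

# Hypothetical usage sketch (not part of this PR): pass logical axes into
# LayerNormMLP so its LayerNorm input and both GEMM inputs receive sharding
# constraints resolved against the current mesh resources.
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES,
                                             HIDDEN_TP_AXES)

mlp = LayerNormMLP(
    intermediate_dim=4096,
    activations=('gelu',),
    transpose_batch_sequence=True,
    layernorm_input_axes=(SEQLEN_AXES, BATCH_AXES, HIDDEN_AXES),
    dot_1_input_axes=(SEQLEN_AXES, BATCH_AXES, HIDDEN_AXES),
    dot_2_input_axes=(SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES),
)

Since both the fused and unfused branches now tag their FFN outputs through checkpoint_name with the shared ffn1/ffn2 names, pairing this module with a remat policy such as jax.checkpoint_policies.save_only_these_names('ffn1', 'ffn2') in a scanned layer stack is one way to keep exactly those activations; treat that pairing as an illustration rather than something the diff enforces.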
@@ -4,6 +4,8 @@
"""JAX layernorm modules"""
from functools import partial
from typing import Tuple
import jax
import jax.numpy as jnp
@@ -12,6 +14,7 @@ from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def canonicalize_layernorm_type(x):
@@ -96,14 +99,20 @@ def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
def layernorm_fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None, # The logical axes of sharding constraint to the layernorm input.
dot_input_axes: Tuple[str, ...] = None # The logical axes of sharding constraint to the dot input.
) -> jnp.ndarray:
""" """
Layernorm + FP8 GEMM Layernorm + FP8 GEMM
""" """
...@@ -114,18 +123,21 @@ def layernorm_fp8_dot(x: jnp.ndarray, ...@@ -114,18 +123,21 @@ def layernorm_fp8_dot(x: jnp.ndarray,
fwd_dtype = FP8Helper.FWD_DTYPE fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv, output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_input_axes)
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_input_axes)
return output
@@ -142,7 +154,9 @@ def _layernorm_fp8_dot_fwd_rule(
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes):
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
@@ -156,6 +170,8 @@ def _layernorm_fp8_dot_fwd_rule(
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
@@ -191,6 +207,8 @@ def _layernorm_fp8_dot_fwd_rule(
casted_kernel, updated_kernel_amax = \
cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, k_contracting_dims),
@@ -209,6 +227,8 @@ def _layernorm_fp8_dot_bwd_rule(
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad):
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
@@ -243,6 +263,7 @@ def _layernorm_fp8_dot_bwd_rule(
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad,
x,
...
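
For clarity, here is a minimal sketch of how a caller might pass the two new axis arguments to layernorm_fp8_dot. The parameter names match the signature above; the wrapper function, the module import path, and the axis ordering are assumptions.

# Hypothetical helper (not part of this PR) exercising the new keyword
# arguments; fp8_meta_pkg is assumed to be an FP8MetaPackage prepared by
# the caller's FP8 bookkeeping.
from transformer_engine.jax.layernorm import layernorm_fp8_dot
from transformer_engine.jax.sharding import BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES

def ln_fp8_projection(x, kernel, gamma, beta, fp8_meta_pkg):
    """LayerNorm + FP8 GEMM with sharding hints on both the norm and dot inputs."""
    return layernorm_fp8_dot(x, kernel, gamma, beta, fp8_meta_pkg, 'layernorm',
                             zero_centered_gamma=False, epsilon=1e-6,
                             layernorm_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
                             dot_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))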
@@ -77,6 +77,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
@@ -109,6 +110,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits)
@@ -156,11 +158,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
@@ -221,11 +226,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init)
...
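
To make the new Praxis fields concrete, a hypothetical fiddle configuration is sketched below. Only enable_rotary_pos_emb, rotary_pos_emb_windows, and enable_sequence_parallel come from this diff; the import paths and the remaining values are assumptions.

# Hypothetical config (not part of this PR) enabling RoPE and sequence
# parallelism on the Praxis TransformerLayer wrapper.
from praxis import pax_fiddle
from transformer_engine.jax.praxis import TransformerLayer

layer_tpl = pax_fiddle.Config(
    TransformerLayer,
    name='te_transformer_layer',
    enable_rotary_pos_emb=True,            # forwarded down to MultiHeadAttention
    rotary_pos_emb_windows=(1, 10000),     # rotary timescale window, per the default above
    enable_sequence_parallel=True,         # shard the sequence dimension along TP
    transpose_batch_sequence=True,
)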
@@ -4,7 +4,7 @@
"""
Sharding Meta for xmap with CustomCall
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -16,6 +16,19 @@ from jax.sharding import PartitionSpec
_PXLA_THREAD_RESOURCES = pxla.thread_resources
# Axis Names
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
SEQLEN_TP_AXES = 'nvte_seqlen_tp'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
@@ -24,6 +37,81 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource
def get_sharding_map_logic_axis_to_mesh_axis():
"""
Generate a dict to map logical axes to mesh axes.
"""
gsr = global_mesh_resource()
IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))
batch_resources = [gsr.fsdp_resource, gsr.dp_resource] if IS_FSDP_OUTER \
else [gsr.dp_resource, gsr.fsdp_resource]
batch_dim_rule = []
for resource in batch_resources:
if resource is not None and resource not in batch_dim_rule:
batch_dim_rule.append(resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
te_logical_axis_to_mesh_axis = {
BATCH_AXES: batch_dim_rule,
SEQLEN_AXES: None,
SEQLEN_TP_AXES: gsr.tp_resource,
HEAD_AXES: gsr.tp_resource,
HIDDEN_AXES: None,
HIDDEN_TP_AXES: gsr.tp_resource,
JOINED_AXES: None,
W_NO_SHARD_AXES: None,
W_FSDP_AXES: gsr.fsdp_resource,
W_TP_AXES: gsr.tp_resource,
W_JOINED_AXES: None,
}
return te_logical_axis_to_mesh_axis
def generate_pspec(logical_axis_names):
"""
Convert logical axes to PartitionSpec
"""
rules = get_sharding_map_logic_axis_to_mesh_axis()
mesh_axis_names = [rules[name] for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return pspec
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
if pspec is None:
return x
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
"""
if logical_axis_names is None:
return x
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
def get_all_mesh_axes():
"""
Get all names of mesh axes
@@ -42,17 +130,6 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
"""
A wrapper function to invoke lax.p* operations, like psum.
...
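
Finally, a sketch of the new logical-axis helpers in action. The helper names and logical axes come from this diff; the 2x4 device layout, the mesh axis names, and the fp8_autocast/MeshResource pairing used to register the resources are assumptions for illustration.

# Hypothetical end-to-end sketch (not part of this PR): constrain an
# activation with logical axes under a DP x TP mesh.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh
from transformer_engine.jax import fp8_autocast, MeshResource
from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES,
                                             with_sharding_constraint_by_logical_axes)

# Assumed 8-device host arranged as 2-way DP x 4-way TP.
devices = np.asarray(jax.devices()).reshape(2, 4)

@jax.jit
def constrain(x):
    # With dp_resource='data' and tp_resource='model', these logical axes
    # resolve to PartitionSpec('data', None, 'model').
    return with_sharding_constraint_by_logical_axes(
        x, (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES))

with Mesh(devices, ('data', 'model')):
    with fp8_autocast(enabled=False,
                      mesh_resource=MeshResource(dp_resource='data', tp_resource='model')):
        y = constrain(jnp.zeros((16, 128, 1024)))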